diff --git a/.gitignore b/.gitignore index fc60f0e3ad..48f4fbd9be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +__pycache__ + +*.bak +*.log *.tgz **/logs/ diff --git a/Cargo.lock b/Cargo.lock index 72642469fc..7e281e9de4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1302,9 +1302,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.49" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4512b90fa68d3a9932cea5184017c5d200f5921df706d45e853537dea51508f" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" dependencies = [ "clap_builder", "clap_derive", @@ -1312,9 +1312,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.49" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0025e98baa12e766c67ba13ff4695a887a1eba19569aad00a472546795bd6730" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" dependencies = [ "anstream", "anstyle", @@ -1980,12 +1980,6 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" -[[package]] -name = "downcast-rs" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "117240f60069e65410b3ae1bb213295bd828f707b5bec6596a1afc8793ce0cbc" - [[package]] name = "dyn-clone" version = "1.0.20" @@ -2330,13 +2324,12 @@ dependencies = [ "anchor-lang", "anyhow", "async-channel 2.5.0", - "async-stream", "async-trait", - "base64 0.22.1", + "base64 0.13.1", "bb8", "borsh 0.10.4", "bs58", - "clap 4.5.49", + "clap 4.5.53", "create-address-test-program", "dashmap 6.1.0", "dotenvy", @@ -2344,7 +2337,6 @@ dependencies = [ "forester-utils", "futures", "itertools 0.14.0", - "kameo", "lazy_static", "light-account-checks", "light-batched-merkle-tree", @@ -2357,17 +2349,13 @@ dependencies = [ "light-hash-set", "light-hasher", 
"light-merkle-tree-metadata", - "light-merkle-tree-reference", "light-program-test", "light-prover-client", "light-registry", "light-sdk", - "light-sparse-merkle-tree", "light-system-program-anchor", "light-test-utils", "light-token-client", - "num-bigint 0.4.6", - "once_cell", "photon-api", "prometheus", "rand 0.8.5", @@ -2412,6 +2400,7 @@ dependencies = [ "light-ctoken-interface", "light-hash-set", "light-hasher", + "light-indexed-array", "light-indexed-merkle-tree", "light-merkle-tree-metadata", "light-merkle-tree-reference", @@ -2419,6 +2408,7 @@ dependencies = [ "light-registry", "light-sdk", "light-sparse-merkle-tree", + "num-bigint 0.4.6", "num-traits", "serde", "serde_json", @@ -2750,14 +2740,14 @@ checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" [[package]] name = "headers" -version = "0.3.9" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06683b93020a07e3dbcf5f8c0f6d40080d725bea7936fc01ad345c01b97dc270" +checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb" dependencies = [ - "base64 0.21.7", + "base64 0.22.1", "bytes", "headers-core", - "http 0.2.12", + "http 1.3.1", "httpdate", "mime", "sha1", @@ -2765,11 +2755,11 @@ dependencies = [ [[package]] name = "headers-core" -version = "0.2.0" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7f66481bfee273957b1f20485a4ff3362987f85b2c236580d81b4eb7a326429" +checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" dependencies = [ - "http 0.2.12", + "http 1.3.1", ] [[package]] @@ -3400,33 +3390,6 @@ dependencies = [ "serde", ] -[[package]] -name = "kameo" -version = "0.19.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c4af7638c67029fd6821d02813c3913c803784648725d4df4082c9b91d7cbb1" -dependencies = [ - "downcast-rs", - "dyn-clone", - "futures", - "kameo_macros", - "serde", - "tokio", - "tracing", -] - 
-[[package]] -name = "kameo_macros" -version = "0.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a13c324e2d8c8e126e63e66087448b4267e263e6cb8770c56d10a9d0d279d9e2" -dependencies = [ - "heck 0.5.0", - "proc-macro2", - "quote", - "syn 2.0.106", -] - [[package]] name = "keccak" version = "0.1.5" @@ -4010,7 +3973,7 @@ dependencies = [ "account-compression", "anchor-lang", "async-trait", - "base64 0.22.1", + "base64 0.13.1", "borsh 0.10.4", "bs58", "bytemuck", @@ -4553,24 +4516,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "multer" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01acbdc23469fd8fe07ab135923371d5f5a422fbf9c522158677c8eb15bc51c2" -dependencies = [ - "bytes", - "encoding_rs", - "futures-util", - "http 0.2.12", - "httparse", - "log", - "memchr", - "mime", - "spin", - "version_check", -] - [[package]] name = "native-tls" version = "0.2.14" @@ -5266,9 +5211,9 @@ dependencies = [ [[package]] name = "prometheus" -version = "0.13.4" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d33c28a30771f7f96db69893f78b857f7450d7e0237e9c8fc6427a81bae7ed1" +checksum = "3ca5326d8d0b950a9acd87e6a3f94745394f62e4dae1b1ee22b2bc0c394af43a" dependencies = [ "cfg-if", "fnv", @@ -5276,14 +5221,28 @@ dependencies = [ "memchr", "parking_lot", "protobuf", - "thiserror 1.0.69", + "thiserror 2.0.17", ] [[package]] name = "protobuf" -version = "2.28.0" +version = "3.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" +checksum = "d65a1d4ddae7d8b5de68153b48f6aa3bba8cb002b243dbdbc55a5afbc98f99f4" +dependencies = [ + "once_cell", + "protobuf-support", + "thiserror 1.0.69", +] + +[[package]] +name = "protobuf-support" +version = "3.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"3e36c2f31e0a47f9280fb347ef5e461ffcd2c52dd520d8e216b52f93b0b0d7d6" +dependencies = [ + "thiserror 1.0.69", +] [[package]] name = "qstring" @@ -8114,8 +8073,8 @@ dependencies = [ "thiserror 2.0.17", "tokio", "tokio-stream", - "tokio-tungstenite 0.20.1", - "tungstenite 0.20.1", + "tokio-tungstenite", + "tungstenite", "url", ] @@ -9369,12 +9328,6 @@ dependencies = [ "zeroize", ] -[[package]] -name = "spin" -version = "0.9.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" - [[package]] name = "spinning_top" version = "0.3.0" @@ -10593,7 +10546,6 @@ dependencies = [ "signal-hook-registry", "socket2 0.6.1", "tokio-macros", - "tracing", "windows-sys 0.61.2", ] @@ -10702,22 +10654,10 @@ dependencies = [ "rustls 0.21.12", "tokio", "tokio-rustls 0.24.1", - "tungstenite 0.20.1", + "tungstenite", "webpki-roots 0.25.4", ] -[[package]] -name = "tokio-tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite 0.21.0", -] - [[package]] name = "tokio-util" version = "0.6.10" @@ -11034,25 +10974,6 @@ dependencies = [ "webpki-roots 0.24.0", ] -[[package]] -name = "tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http 1.3.1", - "httparse", - "log", - "rand 0.8.5", - "sha1", - "thiserror 1.0.69", - "url", - "utf-8", -] - [[package]] name = "typenum" version = "1.19.0" @@ -11250,20 +11171,19 @@ dependencies = [ [[package]] name = "warp" -version = "0.3.7" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"4378d202ff965b011c64817db11d5829506d3404edeadb61f190d111da3f231c" +checksum = "51d06d9202adc1f15d709c4f4a2069be5428aa912cc025d6f268ac441ab066b0" dependencies = [ "bytes", - "futures-channel", "futures-util", "headers", - "http 0.2.12", - "hyper 0.14.32", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", "log", "mime", "mime_guess", - "multer", "percent-encoding", "pin-project", "scoped-tls", @@ -11271,7 +11191,6 @@ dependencies = [ "serde_json", "serde_urlencoded", "tokio", - "tokio-tungstenite 0.21.0", "tokio-util 0.7.16", "tower-service", "tracing", @@ -11929,9 +11848,9 @@ dependencies = [ "anyhow", "ark-bn254 0.5.0", "ark-ff 0.5.0", - "base64 0.22.1", + "base64 0.13.1", "chrono", - "clap 4.5.49", + "clap 4.5.53", "dirs", "groth16-solana", "light-batched-merkle-tree", diff --git a/Cargo.toml b/Cargo.toml index ce6c9d3342..2b5cfe00a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,5 @@ [workspace] + members = [ "program-libs/account-checks", "program-libs/array-map", @@ -146,12 +147,14 @@ quote = "1.0" syn = { version = "2.0", features = ["visit-mut", "full"] } # Async ecosystem -futures = "0.3.17" -tokio = { version = "1.45.1", features = ["rt", "macros", "rt-multi-thread"] } +futures = "0.3.31" +tokio = { version = "1.48.0", features = ["rt", "macros", "rt-multi-thread"] } async-trait = "0.1.82" bb8 = "0.8.6" lazy_static = "1.5.0" +dashmap = "6.1.0" + # Logging log = "0.4" env_logger = "0.11" @@ -236,6 +239,7 @@ tabled = "0.20" num-traits = "0.2.19" zerocopy = { version = "0.8.25" } base64 = "0.13" +chrono = "0.4" zeroize = "=1.3.0" bitvec = { version = "1.0.1", default-features = false } # HTTP client diff --git a/forester-utils/Cargo.toml b/forester-utils/Cargo.toml index 5bbea231f9..f01848ce00 100644 --- a/forester-utils/Cargo.toml +++ b/forester-utils/Cargo.toml @@ -18,6 +18,7 @@ light-hash-set = { workspace = true } light-hasher = { workspace = true, features = ["poseidon"] } light-concurrent-merkle-tree = { workspace = true } 
light-indexed-merkle-tree = { workspace = true } +light-indexed-array = { workspace = true } light-compressed-account = { workspace = true, features = ["std"] } light-batched-merkle-tree = { workspace = true } light-merkle-tree-metadata = { workspace = true } @@ -49,6 +50,7 @@ anyhow = { workspace = true } tracing = { workspace = true } num-traits = { workspace = true } +num-bigint = { workspace = true } bb8 = { workspace = true } async-trait = { workspace = true } diff --git a/forester-utils/src/address_staging_tree.rs b/forester-utils/src/address_staging_tree.rs new file mode 100644 index 0000000000..786ddb6ac0 --- /dev/null +++ b/forester-utils/src/address_staging_tree.rs @@ -0,0 +1,192 @@ +use light_batched_merkle_tree::constants::DEFAULT_BATCH_ADDRESS_TREE_HEIGHT; +use light_hasher::{bigint::bigint_to_be_bytes_array, Poseidon}; +use light_prover_client::proof_types::batch_address_append::{ + get_batch_address_append_circuit_inputs, BatchAddressAppendInputs, +}; +use light_sparse_merkle_tree::{ + changelog::ChangelogEntry, indexed_changelog::IndexedChangelogEntry, SparseMerkleTree, +}; +use thiserror::Error; + +use crate::error::ForesterUtilsError; + +const HEIGHT: usize = DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize; + +#[derive(Debug, Error)] +pub enum AddressStagingTreeError { + #[error( + "Sparse tree root mismatch: computed {computed:?}[..4] != expected {expected:?}[..4] (start_index={start_index})" + )] + SparseRootMismatch { + computed: [u8; 32], + expected: [u8; 32], + start_index: usize, + }, + + #[error("Failed to build circuit inputs: {source} (next_index={next_index}, epoch={epoch}, tree={tree})")] + CircuitInputs { + source: light_prover_client::errors::ProverClientError, + next_index: usize, + epoch: u64, + tree: String, + }, + + #[error("Subtrees are required for address staging tree")] + MissingSubtrees, + + #[error("Failed to serialize new root: {0}")] + RootSerialization(String), +} + +#[derive(Clone, Debug)] +pub struct AddressBatchResult { + 
pub circuit_inputs: BatchAddressAppendInputs, + pub new_root: [u8; 32], + pub old_root: [u8; 32], +} + +#[derive(Clone, Debug)] +pub struct AddressStagingTree { + sparse_tree: SparseMerkleTree, + changelog: Vec>, + indexed_changelog: Vec>, + current_root: [u8; 32], + next_index: usize, +} + +impl AddressStagingTree { + pub fn new( + subtrees: [[u8; 32]; HEIGHT], + initial_root: [u8; 32], + start_index: usize, + ) -> Result { + let sparse_tree = SparseMerkleTree::::new(subtrees, start_index); + + let computed_root = sparse_tree.root(); + if computed_root != initial_root { + return Err(AddressStagingTreeError::SparseRootMismatch { + computed: computed_root, + expected: initial_root, + start_index, + } + .into()); + } + + tracing::debug!( + "AddressStagingTree::new: start_index={}, root={:?}[..4]", + start_index, + &initial_root[..4] + ); + + Ok(Self { + sparse_tree, + changelog: Vec::new(), + indexed_changelog: Vec::new(), + current_root: initial_root, + next_index: start_index, + }) + } + + pub fn from_nodes( + _nodes: &[u64], + _node_hashes: &[[u8; 32]], + initial_root: [u8; 32], + start_index: usize, + subtrees: Option<[[u8; 32]; HEIGHT]>, + ) -> Result { + match subtrees { + Some(st) => Self::new(st, initial_root, start_index), + None => Err(AddressStagingTreeError::MissingSubtrees.into()), + } + } + + pub fn current_root(&self) -> [u8; 32] { + self.current_root + } + + pub fn next_index(&self) -> usize { + self.next_index + } + + pub fn clear_changelogs(&mut self) { + self.changelog.clear(); + self.indexed_changelog.clear(); + } + + #[allow(clippy::too_many_arguments)] + pub fn process_batch( + &mut self, + addresses: &[[u8; 32]], + low_element_values: &[[u8; 32]], + low_element_next_values: &[[u8; 32]], + low_element_indices: &[u64], + low_element_next_indices: &[u64], + low_element_proofs: &[Vec<[u8; 32]>], + leaves_hashchain: [u8; 32], + zkp_batch_size: usize, + epoch: u64, + tree: &str, + ) -> Result { + let old_root = self.current_root; + let next_index = 
self.next_index; + + tracing::debug!( + "AddressStagingTree::process_batch: next_index={}, zkp_batch_size={}, \ + changelog_len={}, indexed_changelog_len={}, addresses_len={}, epoch={}, tree={}", + next_index, + zkp_batch_size, + self.changelog.len(), + self.indexed_changelog.len(), + addresses.len(), + epoch, + tree + ); + + let inputs = get_batch_address_append_circuit_inputs::( + next_index, + old_root, + low_element_values.to_vec(), + low_element_next_values.to_vec(), + low_element_indices.iter().map(|v| *v as usize).collect(), + low_element_next_indices + .iter() + .map(|v| *v as usize) + .collect(), + low_element_proofs.to_vec(), + addresses.to_vec(), + &mut self.sparse_tree, + leaves_hashchain, + zkp_batch_size, + &mut self.changelog, + &mut self.indexed_changelog, + ) + .map_err(|e| AddressStagingTreeError::CircuitInputs { + source: e, + next_index, + epoch, + tree: tree.to_string(), + })?; + + let new_root = bigint_to_be_bytes_array::<32>(&inputs.new_root) + .map_err(|e| AddressStagingTreeError::RootSerialization(e.to_string()))?; + + self.current_root = new_root; + self.next_index += zkp_batch_size; + + tracing::debug!( + "{:?}[..4] -> {:?}[..4] (batch_size={}, next_index={}, epoch={}, tree={})", + &old_root[..4], + &new_root[..4], + zkp_batch_size, + self.next_index, + epoch, + tree + ); + + Ok(AddressBatchResult { + circuit_inputs: inputs, + new_root, + old_root, + }) + } +} diff --git a/forester-utils/src/error.rs b/forester-utils/src/error.rs index 699e3c546e..8d7ec6bfc0 100644 --- a/forester-utils/src/error.rs +++ b/forester-utils/src/error.rs @@ -3,7 +3,7 @@ use light_client::rpc::RpcError; use light_hasher::HasherError; use thiserror::Error; -use crate::rpc_pool::PoolError; +use crate::{address_staging_tree::AddressStagingTreeError, rpc_pool::PoolError}; #[derive(Error, Debug)] pub enum ForesterUtilsError { @@ -31,4 +31,7 @@ pub enum ForesterUtilsError { #[error("error: {0}")] StagingTree(String), + + #[error(transparent)] + 
AddressStagingTree(#[from] AddressStagingTreeError), } diff --git a/forester-utils/src/forester_epoch.rs b/forester-utils/src/forester_epoch.rs index 5ec755f567..38c006037b 100644 --- a/forester-utils/src/forester_epoch.rs +++ b/forester-utils/src/forester_epoch.rs @@ -341,37 +341,47 @@ impl Epoch { }) } - /// creates forester account and fetches epoch account + /// Creates forester account and fetches epoch account. + /// If `epoch` is provided, registers for that specific epoch. + /// If `epoch` is None, determines the next registerable epoch automatically. pub async fn register( rpc: &mut R, protocol_config: &ProtocolConfig, authority: &Keypair, derivation: &Pubkey, + epoch: Option, ) -> Result, RpcError> { - let epoch_registration = - Self::slots_until_next_epoch_registration(rpc, protocol_config).await?; - if epoch_registration.slots_until_registration_starts > 0 - || epoch_registration.slots_until_registration_ends == 0 - { - return Ok(None); - } + let target_epoch = match epoch { + Some(e) => e, + None => { + // Auto-detect which epoch to register for + let epoch_registration = + Self::slots_until_next_epoch_registration(rpc, protocol_config).await?; + if epoch_registration.slots_until_registration_starts > 0 + || epoch_registration.slots_until_registration_ends == 0 + { + return Ok(None); + } + epoch_registration.epoch + } + }; let instruction = create_register_forester_epoch_pda_instruction( &authority.pubkey(), derivation, - epoch_registration.epoch, + target_epoch, ); let signature = rpc .create_and_send_transaction(&[instruction], &authority.pubkey(), &[authority]) .await?; rpc.confirm_transaction(signature).await?; - let epoch_pda_pubkey = get_epoch_pda_address(epoch_registration.epoch); + let epoch_pda_pubkey = get_epoch_pda_address(target_epoch); let epoch_pda = rpc .get_anchor_account::(&epoch_pda_pubkey) .await? 
.unwrap(); let forester_epoch_pda_pubkey = - get_forester_epoch_pda_from_authority(derivation, epoch_registration.epoch).0; + get_forester_epoch_pda_from_authority(derivation, target_epoch).0; let phases = get_epoch_phases(protocol_config, epoch_pda.epoch); Ok(Some(Self { @@ -399,11 +409,10 @@ impl Epoch { if self.phases.active.end < current_solana_slot || self.phases.active.start > current_solana_slot { - println!("current_solana_slot {:?}", current_solana_slot); - println!("registration phase {:?}", self.phases.registration); - println!("active phase {:?}", self.phases.active); - // return Err(RpcError::EpochNotActive); - panic!("TODO: throw epoch not active error"); + return Err(RpcError::AssertRpcError(format!( + "Epoch not active: current_slot={}, active_phase={}..{}", + current_solana_slot, self.phases.active.start, self.phases.active.end + ))); } let epoch_pda = rpc .get_anchor_account::(&self.epoch_pda) diff --git a/forester-utils/src/instructions/address_batch_update.rs b/forester-utils/src/instructions/address_batch_update.rs deleted file mode 100644 index a70903280a..0000000000 --- a/forester-utils/src/instructions/address_batch_update.rs +++ /dev/null @@ -1,377 +0,0 @@ -use std::{pin::Pin, sync::Arc, time::Duration}; - -use account_compression::processor::initialize_address_merkle_tree::Pubkey; -use async_stream::stream; -use futures::stream::Stream; -use light_batched_merkle_tree::{ - constants::DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, merkle_tree::InstructionDataAddressAppendInputs, -}; -use light_client::{ - indexer::{AddressQueueData, Indexer, QueueElementsV2Options}, - rpc::Rpc, -}; -use light_compressed_account::{ - hash_chain::create_hash_chain_from_slice, instruction_data::compressed_proof::CompressedProof, -}; -use light_hasher::{bigint::bigint_to_be_bytes_array, Poseidon}; -use light_prover_client::{ - proof_client::ProofClient, - proof_types::batch_address_append::get_batch_address_append_circuit_inputs, -}; -use 
light_sparse_merkle_tree::SparseMerkleTree; -use tracing::{debug, error, info, warn}; - -use crate::{error::ForesterUtilsError, rpc_pool::SolanaRpcPool, utils::wait_for_indexer}; - -const MAX_PHOTON_ELEMENTS_PER_CALL: usize = 1000; -const MAX_PROOFS_PER_TX: usize = 4; - -pub struct AddressUpdateConfig { - pub rpc_pool: Arc>, - pub merkle_tree_pubkey: Pubkey, - pub prover_url: String, - pub prover_api_key: Option, - pub polling_interval: Duration, - pub max_wait_time: Duration, -} - -#[allow(clippy::too_many_arguments)] -async fn stream_instruction_data<'a, R: Rpc>( - rpc_pool: Arc>, - merkle_tree_pubkey: Pubkey, - prover_url: String, - prover_api_key: Option, - polling_interval: Duration, - max_wait_time: Duration, - leaves_hash_chains: Vec<[u8; 32]>, - start_index: u64, - zkp_batch_size: u16, - mut current_root: [u8; 32], -) -> impl Stream, ForesterUtilsError>> + Send + 'a -{ - stream! { - let proof_client = Arc::new(ProofClient::with_config(prover_url, polling_interval, max_wait_time, prover_api_key)); - let max_zkp_batches_per_call = calculate_max_zkp_batches_per_call(zkp_batch_size); - let total_chunks = leaves_hash_chains.len().div_ceil(max_zkp_batches_per_call); - - let mut next_queue_index: Option = None; - - for chunk_idx in 0..total_chunks { - let chunk_start = chunk_idx * max_zkp_batches_per_call; - let chunk_end = std::cmp::min(chunk_start + max_zkp_batches_per_call, leaves_hash_chains.len()); - let chunk_hash_chains = &leaves_hash_chains[chunk_start..chunk_end]; - - let elements_for_chunk = chunk_hash_chains.len() * zkp_batch_size as usize; - - { - if chunk_idx > 0 { - debug!("Waiting for indexer to sync before fetching chunk {} data", chunk_idx); - } - let connection = rpc_pool.get_connection().await?; - wait_for_indexer(&*connection).await?; - if chunk_idx > 0 { - debug!("Indexer synced, proceeding with chunk {} fetch", chunk_idx); - } - } - - let address_queue = { - let mut connection = rpc_pool.get_connection().await?; - let indexer = 
connection.indexer_mut()?; - debug!( - "Requesting {} addresses from Photon for chunk {} with start_queue_index={:?}", - elements_for_chunk, chunk_idx, next_queue_index - ); - let options = QueueElementsV2Options::default() - .with_address_queue(next_queue_index, Some(elements_for_chunk as u16)); - match indexer - .get_queue_elements(merkle_tree_pubkey.to_bytes(), options, None) - .await - { - Ok(response) => match response.value.address_queue { - Some(queue) => queue, - None => { - yield Err(ForesterUtilsError::Indexer( - "No address queue data in response".into(), - )); - return; - } - }, - Err(e) => { - yield Err(ForesterUtilsError::Indexer(format!( - "Failed to get queue elements: {}", - e - ))); - return; - } - } - }; - - debug!( - "Photon response for chunk {}: received {} addresses, start_index={}, first_queue_index={:?}, last_queue_index={:?}", - chunk_idx, - address_queue.addresses.len(), - address_queue.start_index, - address_queue.queue_indices.first(), - address_queue.queue_indices.last() - ); - - if let Some(last_queue_index) = address_queue.queue_indices.last() { - next_queue_index = Some(last_queue_index + 1); - debug!( - "Setting next_queue_index={} for chunk {}", - next_queue_index.unwrap(), - chunk_idx + 1 - ); - } - - if chunk_idx == 0 { - if address_queue.addresses.is_empty() { - yield Err(ForesterUtilsError::Indexer( - "No addresses found in indexer response".into(), - )); - return; - } - if address_queue.initial_root != current_root { - warn!("Indexer root does not match on-chain root"); - yield Err(ForesterUtilsError::Indexer( - "Indexer root does not match on-chain root".into(), - )); - return; - } - } - - let (all_inputs, new_current_root) = match get_all_circuit_inputs_for_chunk( - chunk_hash_chains, - &address_queue, - zkp_batch_size, - chunk_start, - start_index, - current_root, - ) { - Ok((inputs, new_root)) => (inputs, new_root), - Err(e) => { - yield Err(e); - return; - } - }; - current_root = new_current_root; - - info!("Generating 
{} zk proofs for batch_address chunk {}", all_inputs.len(), chunk_idx + 1); - - let proof_futures: Vec<_> = all_inputs.into_iter().enumerate().map(|(i, inputs)| { - let client = Arc::clone(&proof_client); - async move { - let result = client.generate_batch_address_append_proof(inputs).await; - (i, result) - } - }).collect(); - - let proof_results = futures::future::join_all(proof_futures).await; - - let mut proof_buffer = Vec::new(); - for (idx, result) in proof_results { - match result { - Ok((compressed_proof, new_root)) => { - let instruction_data = InstructionDataAddressAppendInputs { - new_root, - compressed_proof: CompressedProof { - a: compressed_proof.a, - b: compressed_proof.b, - c: compressed_proof.c, - }, - }; - proof_buffer.push(instruction_data); - - if proof_buffer.len() >= MAX_PROOFS_PER_TX { - yield Ok(proof_buffer.clone()); - proof_buffer.clear(); - } - }, - Err(e) => { - error!("Address proof failed to generate at index {}: {:?}", idx, e); - yield Err(ForesterUtilsError::Prover(format!( - "Address proof generation failed at batch {} in chunk {}: {}", - idx, chunk_idx, e - ))); - return; - } - } - } - - // Yield any remaining proofs - if !proof_buffer.is_empty() { - yield Ok(proof_buffer); - } - } - } -} - -fn calculate_max_zkp_batches_per_call(batch_size: u16) -> usize { - std::cmp::max(1, MAX_PHOTON_ELEMENTS_PER_CALL / batch_size as usize) -} - -fn get_all_circuit_inputs_for_chunk( - chunk_hash_chains: &[[u8; 32]], - address_queue: &AddressQueueData, - batch_size: u16, - chunk_start_idx: usize, - global_start_index: u64, - mut current_root: [u8; 32], -) -> Result< - ( - Vec, - [u8; 32], - ), - ForesterUtilsError, -> { - let subtrees_array: [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize] = - address_queue.subtrees.clone().try_into().map_err(|_| { - ForesterUtilsError::Prover("Failed to convert subtrees to array".into()) - })?; - - let mut sparse_merkle_tree = - SparseMerkleTree::::new( - subtrees_array, - global_start_index as usize + 
(chunk_start_idx * batch_size as usize), - ); - - let mut all_inputs = Vec::new(); - let mut changelog = Vec::new(); - let mut indexed_changelog = Vec::new(); - - for (batch_idx, leaves_hash_chain) in chunk_hash_chains.iter().enumerate() { - let start_idx = batch_idx * batch_size as usize; - let end_idx = start_idx + batch_size as usize; - - let addresses_len = address_queue.addresses.len(); - if start_idx >= addresses_len { - return Err(ForesterUtilsError::Indexer(format!( - "Insufficient addresses: batch {} requires start_idx {} but only {} addresses available", - batch_idx, start_idx, addresses_len - ))); - } - let safe_end_idx = std::cmp::min(end_idx, addresses_len); - if safe_end_idx - start_idx != batch_size as usize { - return Err(ForesterUtilsError::Indexer(format!( - "Insufficient addresses: batch {} requires {} addresses (indices {}..{}) but only {} available", - batch_idx, batch_size, start_idx, end_idx, safe_end_idx - start_idx - ))); - } - - let batch_addresses: Vec<[u8; 32]> = - address_queue.addresses[start_idx..safe_end_idx].to_vec(); - - // Check that we have enough low element data - let low_elements_len = address_queue.low_element_values.len(); - if start_idx >= low_elements_len { - return Err(ForesterUtilsError::Indexer(format!( - "Insufficient low element data: batch {} requires start_idx {} but only {} elements available", - batch_idx, start_idx, low_elements_len - ))); - } - let safe_low_end_idx = std::cmp::min(end_idx, low_elements_len); - if safe_low_end_idx - start_idx != batch_size as usize { - return Err(ForesterUtilsError::Indexer(format!( - "Insufficient low element data: batch {} requires {} elements (indices {}..{}) but only {} available", - batch_idx, batch_size, start_idx, end_idx, safe_low_end_idx - start_idx - ))); - } - - let low_element_values: Vec<[u8; 32]> = - address_queue.low_element_values[start_idx..safe_low_end_idx].to_vec(); - let low_element_next_values: Vec<[u8; 32]> = - 
address_queue.low_element_next_values[start_idx..safe_low_end_idx].to_vec(); - let low_element_indices: Vec = address_queue.low_element_indices - [start_idx..safe_low_end_idx] - .iter() - .map(|&x| x as usize) - .collect(); - let low_element_next_indices: Vec = address_queue.low_element_next_indices - [start_idx..safe_low_end_idx] - .iter() - .map(|&x| x as usize) - .collect(); - let low_element_proofs: Vec> = - address_queue.low_element_proofs[start_idx..safe_low_end_idx].to_vec(); - - let computed_hash_chain = create_hash_chain_from_slice(&batch_addresses)?; - if computed_hash_chain != *leaves_hash_chain { - return Err(ForesterUtilsError::Prover( - "Addresses hash chain does not match".into(), - )); - } - - let adjusted_start_index = global_start_index as usize - + (chunk_start_idx * batch_size as usize) - + (batch_idx * batch_size as usize); - - let inputs = get_batch_address_append_circuit_inputs( - adjusted_start_index, - current_root, - low_element_values, - low_element_next_values, - low_element_indices, - low_element_next_indices, - low_element_proofs, - batch_addresses, - &mut sparse_merkle_tree, - *leaves_hash_chain, - batch_size as usize, - &mut changelog, - &mut indexed_changelog, - ) - .map_err(|e| ForesterUtilsError::Prover(format!("Failed to get circuit inputs: {}", e)))?; - - current_root = bigint_to_be_bytes_array::<32>(&inputs.new_root)?; - all_inputs.push(inputs); - } - - Ok((all_inputs, current_root)) -} - -pub async fn get_address_update_instruction_stream<'a, R: Rpc>( - config: AddressUpdateConfig, - merkle_tree_data: crate::ParsedMerkleTreeData, -) -> Result< - ( - Pin< - Box< - dyn Stream< - Item = Result, ForesterUtilsError>, - > + Send - + 'a, - >, - >, - u16, - ), - ForesterUtilsError, -> { - let (current_root, leaves_hash_chains, start_index, zkp_batch_size) = ( - merkle_tree_data.current_root, - merkle_tree_data.leaves_hash_chains, - // merkle_tree_data.batch_start_index, - merkle_tree_data.next_index, - merkle_tree_data.zkp_batch_size, 
- ); - - if leaves_hash_chains.is_empty() { - debug!("No hash chains to process for address update, returning empty stream."); - return Ok((Box::pin(futures::stream::empty()), zkp_batch_size)); - } - - let stream = stream_instruction_data( - config.rpc_pool, - config.merkle_tree_pubkey, - config.prover_url, - config.prover_api_key, - config.polling_interval, - config.max_wait_time, - leaves_hash_chains, - start_index, - zkp_batch_size, - current_root, - ) - .await; - - Ok((Box::pin(stream), zkp_batch_size)) -} diff --git a/forester-utils/src/instructions/mod.rs b/forester-utils/src/instructions/mod.rs index 8c70c1ca69..4173dcc027 100644 --- a/forester-utils/src/instructions/mod.rs +++ b/forester-utils/src/instructions/mod.rs @@ -1,4 +1,3 @@ -pub mod address_batch_update; pub mod create_account; pub use create_account::create_account_instruction; diff --git a/forester-utils/src/lib.rs b/forester-utils/src/lib.rs index 13bb1db513..8cb9864ebe 100644 --- a/forester-utils/src/lib.rs +++ b/forester-utils/src/lib.rs @@ -4,6 +4,7 @@ pub mod account_zero_copy; pub mod address_merkle_tree_config; +pub mod address_staging_tree; pub mod error; pub mod forester_epoch; pub mod instructions; diff --git a/forester-utils/src/rate_limiter.rs b/forester-utils/src/rate_limiter.rs index 7eda3126a2..546f4726a1 100644 --- a/forester-utils/src/rate_limiter.rs +++ b/forester-utils/src/rate_limiter.rs @@ -46,8 +46,7 @@ impl RateLimiter { } pub async fn acquire_with_wait(&self) { - let _start = self.governor.until_ready().await; - tokio::time::sleep(Duration::from_millis(1)).await; + self.governor.until_ready().await; } } diff --git a/forester-utils/src/staging_tree.rs b/forester-utils/src/staging_tree.rs index cefd09407a..604853d347 100644 --- a/forester-utils/src/staging_tree.rs +++ b/forester-utils/src/staging_tree.rs @@ -5,8 +5,6 @@ use tracing::debug; use crate::error::ForesterUtilsError; -pub const TREE_HEIGHT: usize = 32; - /// Result of a batch update operation on a staging tree. 
#[derive(Clone, Debug)] pub struct BatchUpdateResult { @@ -110,6 +108,7 @@ impl StagingTree { Ok(()) } + #[allow(clippy::too_many_arguments)] pub fn process_batch_updates( &mut self, leaf_indices: &[u64], @@ -117,6 +116,8 @@ impl StagingTree { batch_type: BatchType, batch_idx: usize, batch_seq: u64, + epoch: u64, + tree: &str, ) -> Result { if leaf_indices.len() != new_leaves.len() { return Err(ForesterUtilsError::StagingTree(format!( @@ -169,11 +170,13 @@ impl StagingTree { self.current_root = new_root; debug!( - "{} batch {} root transition: {:?}[..4] -> {:?}[..4]", + "{} batch {} root transition: {:?}[..4] -> {:?}[..4] (epoch={}, tree={})", batch_type, batch_idx, &old_root[..4], - &new_root[..4] + &new_root[..4], + epoch, + tree ); Ok(BatchUpdateResult { @@ -205,19 +208,21 @@ impl StagingTree { node_hashes: &[[u8; 32]], initial_root: [u8; 32], root_seq: u64, + height: usize, ) -> Result { debug!( - "StagingTree::new: {} leaves, {} deduplicated nodes, initial_root={:?}, root_seq={}", + "StagingTree::new: {} leaves, {} deduplicated nodes, initial_root={:?}, root_seq={}, height={}", leaves.len(), nodes.len(), &initial_root, - root_seq + root_seq, + height ); - let mut tree = MerkleTree::::new(TREE_HEIGHT, 0); + let mut tree = MerkleTree::::new(height, 0); for (&node_index, &node_hash) in nodes.iter().zip(node_hashes.iter()) { // Skip nodes at root level - root is stored separately in tree.roots let level = (node_index >> 56) as usize; - if level == TREE_HEIGHT { + if level == height { continue; } tree.insert_node(node_index, node_hash).map_err(|e| { diff --git a/forester/.gitignore b/forester/.gitignore index 5f02644564..f6072a93ee 100644 --- a/forester/.gitignore +++ b/forester/.gitignore @@ -2,5 +2,6 @@ logs /target .idea .env +.env.devnet *.json !package.json diff --git a/forester/Cargo.toml b/forester/Cargo.toml index 1aa74e6168..ad9c14484e 100644 --- a/forester/Cargo.toml +++ b/forester/Cargo.toml @@ -6,7 +6,7 @@ publish = false [dependencies] anchor-lang = { 
workspace = true } -clap = { version = "4.5.27", features = ["derive", "env"] } +clap = { version = "4.5.53", features = ["derive", "env"] } solana-sdk = { workspace = true } solana-client = { workspace = true } solana-account-decoder = { workspace = true } @@ -18,14 +18,12 @@ light-compressed-account = { workspace = true, features = ["std"] } light-system-program-anchor = { workspace = true, features = ["cpi"] } light-hash-set = { workspace = true, features = ["solana"] } light-hasher = { workspace = true, features = ["poseidon"] } -light-merkle-tree-reference = { workspace = true } light-prover-client = { workspace = true } light-registry = { workspace = true } photon-api = { workspace = true } forester-utils = { workspace = true } light-client = { workspace = true, features = ["v2"] } light-merkle-tree-metadata = { workspace = true } -light-sparse-merkle-tree = { workspace = true } light-sdk = { workspace = true, features = ["anchor"] } light-program-test = { workspace = true } light-compressible = { workspace = true } @@ -34,17 +32,15 @@ light-ctoken-sdk = { workspace = true } solana-rpc-client-api = { workspace = true } solana-transaction-status = { workspace = true } bb8 = { workspace = true } -base64 = "0.22" # make workspace dep - -serde_json = "1.0" -serde = { version = "1.0", features = ["derive"] } -tokio = { version = "1", features = ["full"] } +base64 = { workspace = true } +serde_json = { workspace = true } +serde = { workspace = true } +tokio = { workspace = true, features = ["full"] } reqwest = { workspace = true, features = ["json", "rustls-tls", "blocking"] } -futures = "0.3.31" -async-stream = "0.3" +futures = { workspace = true } thiserror = { workspace = true } borsh = { workspace = true } -bs58 = "0.5.1" +bs58 = { workspace = true } env_logger = { workspace = true } async-trait = { workspace = true } tracing = { workspace = true } @@ -52,16 +48,13 @@ tracing-subscriber = { workspace = true } tracing-appender = { workspace = true } anyhow = { 
workspace = true } -prometheus = "0.13" -lazy_static = "1.4" -warp = "0.3" -dashmap = "6.1.0" -scopeguard = "1.2.0" -itertools = "0.14.0" -num-bigint = { workspace = true } -kameo = "0.19" -once_cell = "1.21.3" -async-channel = "2.3" +prometheus = "0.14" +lazy_static = { workspace = true } +warp = "0.4" +dashmap = { workspace = true } +scopeguard = "1.2" +itertools = "0.14" +async-channel = "2.5" solana-pubkey = { workspace = true } [dev-dependencies] @@ -69,10 +62,8 @@ serial_test = { workspace = true } light-prover-client = { workspace = true, features = ["devenv"] } light-test-utils = { workspace = true } light-program-test = { workspace = true, features = ["devenv"] } -light-batched-merkle-tree = { workspace = true, features = ["test-only"] } light-token-client = { workspace = true } dotenvy = "0.15" light-compressed-token = { workspace = true } -light-ctoken-sdk = { workspace = true } rand = { workspace = true } create-address-test-program = { workspace = true } diff --git a/forester/README.md b/forester/README.md index f74be7c1cb..8ad0332d8f 100644 --- a/forester/README.md +++ b/forester/README.md @@ -70,6 +70,7 @@ Control transaction batching and concurrency: | `--transaction-max-concurrent-batches` | 20 | Maximum concurrent transaction batches | | `--cu-limit` | 1000000 | Compute unit limit per transaction | | `--enable-priority-fees` | false | Enable dynamic priority fee calculation | +| `--enable-compressible` | false | Enable compressible account tracking and compression (requires `--ws-rpc-url`) | #### Example diff --git a/forester/scripts/plot.py b/forester/scripts/plot.py new file mode 100644 index 0000000000..15749a932d --- /dev/null +++ b/forester/scripts/plot.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +import argparse +from pathlib import Path + +import pandas as pd +import matplotlib.pyplot as plt +import re + + +# Match ISO8601 timestamp like 2025-12-04T12:21:58.990364Z anywhere in the line +TS_RE = 
re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z") + + +def parse_log(path: Path) -> pd.DataFrame: + """ + Parse a Forester tx_sender log and return a DataFrame with: + index: timestamp (UTC) + column: proofs (ixs value) + + Expected line example: + 2025-12-04T12:21:58.990364Z INFO forester::processor::v2::tx_sender: \ + tx sent: ... type=AddressAppend ixs=4 root=[...] seq=0..3 epoch=26 + """ + rows: list[tuple[str, int]] = [] + + with path.open("r", encoding="utf-8") as f: + for line in f: + if "ixs=" not in line: + continue + + # --- timestamp via regex (robust against ANSI codes, prefixes, etc.) --- + m = TS_RE.search(line) + if not m: + continue + ts = m.group(0) + + # --- parse integer after "ixs=" --- + ix_pos = line.find("ixs=") + if ix_pos == -1: + continue + rest = line[ix_pos + len("ixs=") :] # e.g. "4 root=[...] ..." + num_str = "" + for ch in rest: + if ch.isdigit(): + num_str += ch + else: + break + if not num_str: + continue + + proofs = int(num_str) + rows.append((ts, proofs)) + + if not rows: + raise RuntimeError( + "No lines with 'ixs=' could be parsed. " + "Make sure the log is from forester::processor::v2::tx_sender." + ) + + df = pd.DataFrame(rows, columns=["timestamp", "proofs"]) + df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True, errors="raise") + df = df.sort_values("timestamp").set_index("timestamp") + + return df + + +def make_plots( + df: pd.DataFrame, + rolling_window: int = 3, + show: bool = True, + out: Path | None = None, +): + """ + Create two plots: + 1) proofs per tx over time + 2) proofs per minute + rolling average + + rolling_window is in minutes. 
+ """ + # --- per-tx plot --------------------------------------------------------- + plt.figure(figsize=(10, 4)) + plt.plot(df.index, df["proofs"], marker="o") + plt.title("Proofs per Transaction Over Time") + plt.xlabel("Time") + plt.ylabel("ixs (proofs per tx)") + plt.xticks(rotation=45) + plt.tight_layout() + + if out is not None: + per_tx_path = out.with_suffix(".per_tx.png") + plt.savefig(per_tx_path, dpi=150) + print(f"Saved per-tx plot to: {per_tx_path}") + if show: + plt.show() + + # --- per-minute aggregation + rolling average --------------------------- + # Sum proofs per minute + per_min = df["proofs"].resample("1T").sum() + + # Rolling average over N minutes (trailing window) + rolling = per_min.rolling(window=rolling_window, min_periods=1).mean() + + duration_min = (df.index.max() - df.index.min()).total_seconds() / 60.0 + total_proofs = df["proofs"].sum() + avg_per_min = total_proofs / duration_min if duration_min > 0 else float("nan") + + print(f"Total proofs: {total_proofs}") + print(f"Duration: {duration_min:.2f} minutes") + print(f"Average throughput: {avg_per_min:.2f} proofs/min") + print(f"Rolling window: {rolling_window} minutes") + + plt.figure(figsize=(10, 4)) + plt.plot(per_min.index, per_min.values, marker="o", label="Proofs per minute") + plt.plot( + rolling.index, + rolling.values, + linestyle="--", + marker="x", + label=f"Rolling avg ({rolling_window}-min)", + ) + plt.title("Proof Throughput and Rolling Average") + plt.xlabel("Time") + plt.ylabel("Proofs per minute") + plt.xticks(rotation=45) + plt.legend() + plt.tight_layout() + + if out is not None: + per_min_path = out.with_suffix(".per_min.png") + plt.savefig(per_min_path, dpi=150) + print(f"Saved per-minute plot to: {per_min_path}") + if show: + plt.show() + + +def main(): + parser = argparse.ArgumentParser( + description="Parse Forester tx_sender logs and plot proof throughput." + ) + parser.add_argument("logfile", type=Path, help="Path to log file, e.g. 
tx.log") + parser.add_argument( + "--rolling-window", + type=int, + default=3, + help="Rolling average window in minutes (default: 3).", + ) + parser.add_argument( + "--no-show", + action="store_true", + help="Do not display plots interactively, only save to files (if --out is given).", + ) + parser.add_argument( + "--out", + type=Path, + default=None, + help="Base path to save PNGs (without extension), e.g. ./proofs", + ) + + args = parser.parse_args() + + df = parse_log(args.logfile) + # Debug: uncomment if you want to quickly see parsed data + # print(df.head(), df.tail(), df.shape) + + make_plots( + df, + rolling_window=args.rolling_window, + show=not args.no_show, + out=args.out, + ) + + +if __name__ == "__main__": + main() + diff --git a/forester/scripts/plot_enhanced.py b/forester/scripts/plot_enhanced.py new file mode 100755 index 0000000000..e96f59fad4 --- /dev/null +++ b/forester/scripts/plot_enhanced.py @@ -0,0 +1,791 @@ +#!/usr/bin/env python3 +""" +Enhanced Forester Performance Analysis Tool + +Parses forester logs and generates comprehensive performance visualizations including: +- Proof round-trip latency distribution and timeline +- Transaction throughput with gap analysis +- Pipeline utilization (proof requests vs completions) +- Queue drain rates and bottleneck identification +- Indexer sync wait analysis +""" + +import argparse +import re +from pathlib import Path +from dataclasses import dataclass +from typing import Optional +from datetime import datetime, timedelta + +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.dates as mdates +import numpy as np + +# Regex patterns +TS_RE = re.compile(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z") +ROUND_TRIP_RE = re.compile(r"round_trip=(\d+)ms") +PROOF_TIME_RE = re.compile(r"proof=(\d+)ms") +IXS_RE = re.compile(r"ixs=(\d+)") +TYPE_RE = re.compile(r"type=(\w+)") +QUEUE_ITEMS_RE = re.compile(r"(\d+)\s+items") +BATCH_LIMIT_RE = re.compile(r"Queue size (\d+) would produce (\d+) 
batches, limiting to (\d+)") +CIRCUIT_RE = re.compile(r"circuit type:\s*(\w+)") + + +@dataclass +class ProofEvent: + timestamp: datetime + round_trip_ms: int + seq: Optional[int] = None + job_id: Optional[str] = None + proof_type: Optional[str] = None + proof_ms: Optional[int] = None # Pure proof generation time from prover server + + @property + def queue_wait_ms(self) -> Optional[int]: + """Time spent waiting in queue (round_trip - proof).""" + if self.proof_ms is not None: + return self.round_trip_ms - self.proof_ms + return None + + +@dataclass +class TxEvent: + timestamp: datetime + ixs: int + tx_type: str + tx_hash: str + tree: Optional[str] = None + e2e_ms: Optional[int] = None # End-to-end latency: proof submit → tx sent + + +@dataclass +class ProofRequest: + timestamp: datetime + circuit_type: str + + +@dataclass +class BottleneckEvent: + timestamp: datetime + event_type: str # 'indexer_sync', 'batch_limit', 'idle' + details: str + + +def parse_timestamp(line: str) -> Optional[datetime]: + """Extract timestamp from log line.""" + m = TS_RE.search(line) + if m: + return datetime.fromisoformat(m.group(0).replace('Z', '+00:00')) + return None + + +def parse_log(path: Path) -> dict: + """Parse forester log and extract all performance-relevant events.""" + proof_completions: list[ProofEvent] = [] + tx_events: list[TxEvent] = [] + proof_requests: list[ProofRequest] = [] + bottlenecks: list[BottleneckEvent] = [] + queue_updates: list[tuple[datetime, str, int]] = [] # (ts, tree, items) + job_types: dict[str, str] = {} # job_id -> proof_type + + with path.open("r", encoding="utf-8", errors="replace") as f: + for line in f: + ts = parse_timestamp(line) + if not ts: + continue + + # Submitted proof job (to map job_id -> type) + if "Submitted proof job" in line: + m = re.search(r'type=(\w+)\s+job_id=([a-f0-9-]+)', line) + if m: + job_types[m.group(2)] = m.group(1) + + # Proof completions + if "Proof completed" in line: + m = ROUND_TRIP_RE.search(line) + job_m = 
re.search(r'job_id=([a-f0-9-]+)', line) + proof_m = PROOF_TIME_RE.search(line) + if m: + job_id = job_m.group(1) if job_m else None + proof_type = job_types.get(job_id) if job_id else None + proof_ms = int(proof_m.group(1)) if proof_m else None + proof_completions.append(ProofEvent( + timestamp=ts, + round_trip_ms=int(m.group(1)), + job_id=job_id, + proof_type=proof_type, + proof_ms=proof_ms + )) + + # TX sent + elif "tx sent:" in line: + ixs_m = IXS_RE.search(line) + type_m = TYPE_RE.search(line) + if ixs_m and type_m: + # Extract tx hash (first base58 string after "tx sent:") + hash_start = line.find("tx sent:") + 9 + hash_end = line.find(" ", hash_start) + tx_hash = line[hash_start:hash_end] if hash_end > hash_start else "" + + # Extract tree pubkey (new format) + tree_m = re.search(r'tree=(\w+)', line) + tree = tree_m.group(1) if tree_m else None + + # Extract timing info (new format: e2e=ms) + tx_e2e_m = re.search(r'e2e=(\d+)ms', line) + + tx_events.append(TxEvent( + timestamp=ts, + ixs=int(ixs_m.group(1)), + tx_type=type_m.group(1), + tx_hash=tx_hash, + tree=tree, + e2e_ms=int(tx_e2e_m.group(1)) if tx_e2e_m else None, + )) + + # Proof requests + elif "Submitting async proof request" in line: + m = CIRCUIT_RE.search(line) + if m: + proof_requests.append(ProofRequest( + timestamp=ts, + circuit_type=m.group(1) + )) + + # Indexer sync waits + elif "waiting for indexer sync" in line: + bottlenecks.append(BottleneckEvent( + timestamp=ts, + event_type="indexer_sync", + details=line.strip() + )) + + # Batch limiting + elif "would produce" in line and "limiting to" in line: + m = BATCH_LIMIT_RE.search(line) + if m: + bottlenecks.append(BottleneckEvent( + timestamp=ts, + event_type="batch_limit", + details=f"Queue {m.group(1)} -> {m.group(2)} batches, limited to {m.group(3)}" + )) + + # Queue updates + elif "Routed update to tree" in line: + m = QUEUE_ITEMS_RE.search(line) + if m: + # Extract tree name + tree_start = line.find("tree ") + 5 + tree_end = line.find(":", 
tree_start) + tree = line[tree_start:tree_end] if tree_end > tree_start else "unknown" + queue_updates.append((ts, tree, int(m.group(1)))) + + return { + "proof_completions": proof_completions, + "tx_events": tx_events, + "proof_requests": proof_requests, + "bottlenecks": bottlenecks, + "queue_updates": queue_updates + } + + +def plot_latency_distribution(proof_completions: list[ProofEvent], ax): + """Plot round-trip latency histogram with percentiles.""" + if not proof_completions: + ax.text(0.5, 0.5, "No proof data", ha='center', va='center') + return + + latencies = [p.round_trip_ms for p in proof_completions] + + # Histogram - ensure bins are monotonically increasing + max_latency = max(latencies) if latencies else 1000 + base_bins = [0, 500, 1000, 2000, 5000, 10000, 20000] + # Only keep bins smaller than max_latency, then add final bin + bins = [b for b in base_bins if b < max_latency] + [max_latency + 1000] + if len(bins) < 2: + bins = [0, max_latency + 1000] + ax.hist(latencies, bins=bins, edgecolor='black', alpha=0.7, color='steelblue') + + # Percentile lines + p50 = np.percentile(latencies, 50) + p95 = np.percentile(latencies, 95) + p99 = np.percentile(latencies, 99) + + ax.axvline(p50, color='green', linestyle='--', linewidth=2, label=f'p50: {p50:.0f}ms') + ax.axvline(p95, color='orange', linestyle='--', linewidth=2, label=f'p95: {p95:.0f}ms') + ax.axvline(p99, color='red', linestyle='--', linewidth=2, label=f'p99: {p99:.0f}ms') + + ax.set_xlabel('Round-trip Latency (ms)') + ax.set_ylabel('Count') + ax.set_title(f'Proof Latency Distribution (n={len(latencies)})') + ax.legend() + ax.set_xscale('log') + + +def plot_latency_timeline(proof_completions: list[ProofEvent], ax): + """Plot latency over time with color-coded severity.""" + if not proof_completions: + ax.text(0.5, 0.5, "No proof data", ha='center', va='center') + return + + timestamps = [p.timestamp for p in proof_completions] + latencies = [p.round_trip_ms for p in proof_completions] + + # Color by 
latency bucket + colors = [] + for lat in latencies: + if lat < 1000: + colors.append('green') + elif lat < 5000: + colors.append('yellow') + elif lat < 10000: + colors.append('orange') + else: + colors.append('red') + + ax.scatter(timestamps, latencies, c=colors, alpha=0.6, s=20) + ax.set_xlabel('Time') + ax.set_ylabel('Latency (ms)') + ax.set_title('Proof Latency Over Time') + ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) + ax.tick_params(axis='x', rotation=45) + + +def plot_throughput_gaps(tx_events: list[TxEvent], ax): + """Plot transaction throughput with gap highlighting.""" + if not tx_events: + ax.text(0.5, 0.5, "No TX data", ha='center', va='center') + return + + timestamps = [t.timestamp for t in tx_events] + ixs = [t.ixs for t in tx_events] + + # Calculate gaps + gaps = [] + for i in range(1, len(timestamps)): + gap = (timestamps[i] - timestamps[i-1]).total_seconds() + gaps.append((timestamps[i-1], timestamps[i], gap)) + + # Plot proofs + ax.bar(timestamps, ixs, width=0.0001, color='steelblue', alpha=0.8) + + # Highlight large gaps (> 10s) + for start, end, gap in gaps: + if gap > 10: + ax.axvspan(start, end, alpha=0.3, color='red') + mid = start + (end - start) / 2 + ax.annotate(f'{gap:.0f}s', xy=(mid, max(ixs)*0.9), fontsize=8, ha='center', color='red') + + ax.set_xlabel('Time') + ax.set_ylabel('Proofs per TX') + ax.set_title('Transaction Throughput with Gaps (red = >10s gap)') + ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) + ax.tick_params(axis='x', rotation=45) + + +def plot_pipeline_utilization(proof_requests: list[ProofRequest], + proof_completions: list[ProofEvent], + ax): + """Plot proof request vs completion rates to show pipeline depth.""" + if not proof_requests or not proof_completions: + ax.text(0.5, 0.5, "No pipeline data", ha='center', va='center') + return + + # Resample to 1-second bins + req_times = [p.timestamp for p in proof_requests] + comp_times = [p.timestamp for p in proof_completions] + + 
all_times = req_times + comp_times + if not all_times: + return + + min_t = min(all_times) + max_t = max(all_times) + + # Create time bins (1 second) + bins = [] + t = min_t + while t <= max_t: + bins.append(t) + t += timedelta(seconds=1) + + req_counts = np.zeros(len(bins) - 1) + comp_counts = np.zeros(len(bins) - 1) + + for rt in req_times: + for i in range(len(bins) - 1): + if bins[i] <= rt < bins[i + 1]: + req_counts[i] += 1 + break + + for ct in comp_times: + for i in range(len(bins) - 1): + if bins[i] <= ct < bins[i + 1]: + comp_counts[i] += 1 + break + + bin_centers = [bins[i] + (bins[i+1] - bins[i]) / 2 for i in range(len(bins) - 1)] + + ax.fill_between(bin_centers, req_counts, alpha=0.5, label='Requests', color='blue') + ax.fill_between(bin_centers, comp_counts, alpha=0.5, label='Completions', color='green') + ax.plot(bin_centers, np.cumsum(req_counts) - np.cumsum(comp_counts), + color='red', linewidth=2, label='In-flight (cumulative)') + + ax.set_xlabel('Time') + ax.set_ylabel('Count per second') + ax.set_title('Pipeline Utilization (Requests vs Completions)') + ax.legend() + ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) + ax.tick_params(axis='x', rotation=45) + + +def plot_bottleneck_timeline(bottlenecks: list[BottleneckEvent], ax): + """Plot bottleneck events on a timeline.""" + if not bottlenecks: + ax.text(0.5, 0.5, "No bottleneck events detected", ha='center', va='center') + return + + event_types = {"indexer_sync": 0, "batch_limit": 1} + colors = {"indexer_sync": "red", "batch_limit": "orange"} + + for b in bottlenecks: + y = event_types.get(b.event_type, 2) + ax.scatter([b.timestamp], [y], c=colors.get(b.event_type, 'gray'), s=50, alpha=0.7) + + ax.set_yticks([0, 1]) + ax.set_yticklabels(['Indexer Sync Wait', 'Batch Limit']) + ax.set_xlabel('Time') + ax.set_title(f'Bottleneck Events (n={len(bottlenecks)})') + ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) + ax.tick_params(axis='x', rotation=45) + + +def 
plot_tx_type_breakdown(tx_events: list[TxEvent], ax): + """Pie chart of transaction types.""" + if not tx_events: + ax.text(0.5, 0.5, "No TX data", ha='center', va='center') + return + + type_counts = {} + for t in tx_events: + type_counts[t.tx_type] = type_counts.get(t.tx_type, 0) + 1 + + labels = list(type_counts.keys()) + sizes = list(type_counts.values()) + + ax.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=90) + ax.set_title('Transaction Type Distribution') + + +def plot_tx_by_tree(tx_events: list[TxEvent], ax): + """Bar chart showing TX count and timing by tree.""" + txs_with_tree = [t for t in tx_events if t.tree] + + if not txs_with_tree: + ax.text(0.5, 0.5, "No tree data in TXs\n(old log format)", ha='center', va='center') + return + + # Aggregate by tree + tree_data = {} + for t in txs_with_tree: + short_tree = t.tree[:8] + "..." + if short_tree not in tree_data: + tree_data[short_tree] = {'count': 0, 'e2e_ms': []} + tree_data[short_tree]['count'] += 1 + if t.e2e_ms is not None: + tree_data[short_tree]['e2e_ms'].append(t.e2e_ms) + + trees = sorted(tree_data.keys()) + counts = [tree_data[t]['count'] for t in trees] + + x = np.arange(len(trees)) + bars = ax.bar(x, counts, color='steelblue', alpha=0.8) + + ax.set_ylabel('TX Count') + ax.set_title(f'Transactions by Tree (n={len(txs_with_tree)})') + ax.set_xticks(x) + ax.set_xticklabels(trees, rotation=45, ha='right', fontsize=8) + + # Add count labels on bars + for bar, count in zip(bars, counts): + ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(), + f'{count}', ha='center', va='bottom', fontsize=8) + + # Add avg e2e latency as secondary info + for i, tree in enumerate(trees): + e2e_times = tree_data[tree]['e2e_ms'] + if e2e_times: + avg_e2e = np.mean(e2e_times) + ax.text(i, counts[i] * 0.5, f'{avg_e2e/1000:.1f}s e2e', + ha='center', va='center', fontsize=7, color='white') + + +def plot_time_breakdown_by_type(proof_completions: list[ProofEvent], ax): + """Bar chart showing mean round-trip 
time by proof type.""" + proofs_with_type = [p for p in proof_completions if p.proof_type] + + if not proofs_with_type: + ax.text(0.5, 0.5, "No timing data by type", ha='center', va='center') + return + + # Aggregate by type - use round_trip_ms which is the actual end-to-end time + type_data = {} + for p in proofs_with_type: + if p.proof_type not in type_data: + type_data[p.proof_type] = [] + type_data[p.proof_type].append(p.round_trip_ms) + + types = sorted(type_data.keys()) + means = [np.mean(type_data[t]) for t in types] + medians = [np.median(type_data[t]) for t in types] + p95s = [np.percentile(type_data[t], 95) for t in types] + counts = [len(type_data[t]) for t in types] + + x = np.arange(len(types)) + width = 0.25 + + bars1 = ax.bar(x - width, means, width, label='Mean', color='steelblue') + bars2 = ax.bar(x, medians, width, label='Median', color='green') + bars3 = ax.bar(x + width, p95s, width, label='p95', color='orange') + + ax.set_ylabel('Round-trip Time (ms)') + ax.set_title(f'Round-trip Time by Proof Type (n={len(proofs_with_type)})') + ax.set_xticks(x) + ax.set_xticklabels([f"{t}\n(n={c})" for t, c in zip(types, counts)]) + ax.legend() + + # Add value labels on bars + for bars in [bars1, bars2, bars3]: + for bar in bars: + height = bar.get_height() + ax.text(bar.get_x() + bar.get_width()/2., height, + f'{height:.0f}', + ha='center', va='bottom', fontsize=7) + + +def plot_latency_timeline_by_type(proof_completions: list[ProofEvent], ax): + """Plot latency over time with color by proof type.""" + if not proof_completions: + ax.text(0.5, 0.5, "No proof data", ha='center', va='center') + return + + # Color map for proof types + type_colors = { + 'append': 'blue', + 'update': 'red', + 'address_append': 'green', + } + + for proof_type, color in type_colors.items(): + type_proofs = [p for p in proof_completions if p.proof_type == proof_type] + if type_proofs: + timestamps = [p.timestamp for p in type_proofs] + latencies = [p.round_trip_ms for p in 
type_proofs] + ax.scatter(timestamps, latencies, c=color, alpha=0.6, s=20, label=f'{proof_type} (n={len(type_proofs)})') + + # Handle unknown types + unknown = [p for p in proof_completions if p.proof_type not in type_colors] + if unknown: + timestamps = [p.timestamp for p in unknown] + latencies = [p.round_trip_ms for p in unknown] + ax.scatter(timestamps, latencies, c='gray', alpha=0.4, s=15, label=f'unknown (n={len(unknown)})') + + ax.set_xlabel('Time') + ax.set_ylabel('Round-trip Latency (ms)') + ax.set_title('Latency Over Time by Proof Type') + ax.legend(loc='upper left', fontsize=8) + ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M:%S')) + ax.tick_params(axis='x', rotation=45) + + +def plot_proof_vs_queue_scatter(proof_completions: list[ProofEvent], ax): + """Scatter plot of proof time vs queue wait time, highlighting cache hits.""" + proofs_with_timing = [p for p in proof_completions if p.proof_ms is not None] + + if not proofs_with_timing: + ax.text(0.5, 0.5, "No timing data", ha='center', va='center') + return + + # Separate cache hits (negative queue wait) from fresh proofs + cached = [p for p in proofs_with_timing if p.queue_wait_ms < 0] + fresh = [p for p in proofs_with_timing if p.queue_wait_ms >= 0] + + # Plot fresh proofs by type + type_colors = { + 'append': 'blue', + 'update': 'red', + 'address_append': 'green', + } + + for proof_type, color in type_colors.items(): + type_proofs = [p for p in fresh if p.proof_type == proof_type] + if type_proofs: + proof_times = [p.proof_ms for p in type_proofs] + queue_times = [p.queue_wait_ms for p in type_proofs] + ax.scatter(proof_times, queue_times, c=color, alpha=0.6, s=30, label=f'{proof_type} (n={len(type_proofs)})') + + # Unknown types (fresh) + unknown = [p for p in fresh if p.proof_type not in type_colors] + if unknown: + proof_times = [p.proof_ms for p in unknown] + queue_times = [p.queue_wait_ms for p in unknown] + ax.scatter(proof_times, queue_times, c='gray', alpha=0.4, s=20, label=f'unknown 
(n={len(unknown)})') + + # Plot cache hits separately (below zero line) + if cached: + proof_times = [p.proof_ms for p in cached] + queue_times = [p.queue_wait_ms for p in cached] + ax.scatter(proof_times, queue_times, c='lime', alpha=0.5, s=25, marker='v', + label=f'cache hits (n={len(cached)})') + + # Add zero line to show cache hit boundary + ax.axhline(y=0, color='black', linestyle='-', linewidth=1, alpha=0.5) + + # Add diagonal line where proof == queue (for fresh proofs only) + if fresh: + max_val = max(max(p.proof_ms for p in fresh), max(p.queue_wait_ms for p in fresh)) + ax.plot([0, max_val], [0, max_val], 'k--', alpha=0.3) + + ax.set_xlabel('Pure Proof Time (ms)') + ax.set_ylabel('Queue Wait Time (ms)\n(negative = cache hit)') + cache_pct = len(cached) / len(proofs_with_timing) * 100 if proofs_with_timing else 0 + ax.set_title(f'Proof vs Queue Wait ({cache_pct:.0f}% cache hits)') + ax.legend(loc='upper right', fontsize=7) + + +def print_summary(data: dict): + """Print summary statistics to console.""" + proof_completions = data["proof_completions"] + tx_events = data["tx_events"] + bottlenecks = data["bottlenecks"] + + print("\n" + "="*60) + print("FORESTER PERFORMANCE SUMMARY") + print("="*60) + + if proof_completions: + latencies = [p.round_trip_ms for p in proof_completions] + print(f"\nProof Latency Statistics (n={len(latencies)}):") + print(f" Min: {min(latencies):,} ms") + print(f" Max: {max(latencies):,} ms") + print(f" Mean: {np.mean(latencies):,.1f} ms") + print(f" Median: {np.median(latencies):,.1f} ms") + print(f" p95: {np.percentile(latencies, 95):,.1f} ms") + print(f" p99: {np.percentile(latencies, 99):,.1f} ms") + + # Latency buckets + print("\n Distribution:") + buckets = [(0, 500), (500, 1000), (1000, 2000), (2000, 5000), (5000, 10000), (10000, float('inf'))] + bucket_names = ["<500ms", "500-1000ms", "1-2s", "2-5s", "5-10s", ">10s"] + for (lo, hi), name in zip(buckets, bucket_names): + count = sum(1 for l in latencies if lo <= l < hi) + pct 
= count / len(latencies) * 100 + bar = '#' * int(pct / 2) + print(f" {name:>12}: {count:4d} ({pct:5.1f}%) {bar}") + + # Time breakdown: proof vs queue wait + proofs_with_timing = [p for p in proof_completions if p.proof_ms is not None] + if proofs_with_timing: + # Separate cache hits (pre-warmed) from fresh proofs + cached_proofs = [p for p in proofs_with_timing if p.queue_wait_ms < 0] + fresh_proofs = [p for p in proofs_with_timing if p.queue_wait_ms >= 0] + + cache_hit_rate = len(cached_proofs) / len(proofs_with_timing) * 100 + print(f"\n Cache Statistics (n={len(proofs_with_timing)} with timing data):") + print(f" Cache hits (pre-warmed): {len(cached_proofs):,} ({cache_hit_rate:.1f}%)") + print(f" Fresh proofs: {len(fresh_proofs):,} ({100-cache_hit_rate:.1f}%)") + + if cached_proofs: + cached_latencies = [p.round_trip_ms for p in cached_proofs] + print(f" Cache hit latency: {np.mean(cached_latencies):.0f}ms mean, {np.median(cached_latencies):.0f}ms median") + + if fresh_proofs: + proof_times = [p.proof_ms for p in fresh_proofs] + queue_waits = [p.queue_wait_ms for p in fresh_proofs] + + print(f"\n Time Breakdown (fresh proofs only, n={len(fresh_proofs)}):") + print(f" Pure Proof Time:") + print(f" Min: {min(proof_times):,} ms") + print(f" Max: {max(proof_times):,} ms") + print(f" Mean: {np.mean(proof_times):,.1f} ms") + print(f" Median: {np.median(proof_times):,.1f} ms") + print(f" p95: {np.percentile(proof_times, 95):,.1f} ms") + + print(f" Queue Wait Time (round_trip - proof):") + print(f" Min: {min(queue_waits):,} ms") + print(f" Max: {max(queue_waits):,} ms") + print(f" Mean: {np.mean(queue_waits):,.1f} ms") + print(f" Median: {np.median(queue_waits):,.1f} ms") + print(f" p95: {np.percentile(queue_waits, 95):,.1f} ms") + + # Percentage breakdown + total_time = sum(p.round_trip_ms for p in fresh_proofs) + total_proof = sum(proof_times) + total_queue = sum(queue_waits) + print(f"\n Time Distribution (fresh proofs):") + print(f" Proof generation: 
{total_proof/total_time*100:5.1f}% of total time") + print(f" Queue wait: {total_queue/total_time*100:5.1f}% of total time") + + # Latency by proof type + type_latencies = {} + for p in proof_completions: + if p.proof_type: + if p.proof_type not in type_latencies: + type_latencies[p.proof_type] = {'round_trip': [], 'proof': [], 'queue': []} + type_latencies[p.proof_type]['round_trip'].append(p.round_trip_ms) + if p.proof_ms is not None: + type_latencies[p.proof_type]['proof'].append(p.proof_ms) + type_latencies[p.proof_type]['queue'].append(p.queue_wait_ms) + + if type_latencies: + print("\n Latency by Proof Type (round_trip):") + print(f" {'Type':<18} {'Count':>6} {'Min':>8} {'p50':>8} {'Mean':>8} {'p95':>8} {'Max':>8}") + print(" " + "-"*66) + for proof_type in sorted(type_latencies.keys()): + lats = type_latencies[proof_type]['round_trip'] + if lats: + print(f" {proof_type:<18} {len(lats):>6} {min(lats):>7}ms {np.percentile(lats, 50):>7.0f}ms {np.mean(lats):>7.0f}ms {np.percentile(lats, 95):>7.0f}ms {max(lats):>7}ms") + + # Show proof time breakdown by type + has_proof_timing = any(type_latencies[t]['proof'] for t in type_latencies) + if has_proof_timing: + print("\n Pure Proof Time by Type:") + print(f" {'Type':<18} {'Count':>6} {'Min':>8} {'p50':>8} {'Mean':>8} {'p95':>8} {'Max':>8}") + print(" " + "-"*66) + for proof_type in sorted(type_latencies.keys()): + lats = type_latencies[proof_type]['proof'] + if lats: + print(f" {proof_type:<18} {len(lats):>6} {min(lats):>7}ms {np.percentile(lats, 50):>7.0f}ms {np.mean(lats):>7.0f}ms {np.percentile(lats, 95):>7.0f}ms {max(lats):>7}ms") + + print("\n Queue Wait Time by Type:") + print(f" {'Type':<18} {'Count':>6} {'Min':>8} {'p50':>8} {'Mean':>8} {'p95':>8} {'Max':>8}") + print(" " + "-"*66) + for proof_type in sorted(type_latencies.keys()): + lats = type_latencies[proof_type]['queue'] + if lats: + print(f" {proof_type:<18} {len(lats):>6} {min(lats):>7}ms {np.percentile(lats, 50):>7.0f}ms {np.mean(lats):>7.0f}ms 
{np.percentile(lats, 95):>7.0f}ms {max(lats):>7}ms") + + if tx_events: + total_proofs = sum(t.ixs for t in tx_events) + duration = (tx_events[-1].timestamp - tx_events[0].timestamp).total_seconds() + + print(f"\nTransaction Statistics:") + print(f" Total TXs: {len(tx_events)}") + print(f" Total Proofs: {total_proofs}") + print(f" Duration: {duration:.1f}s ({duration/60:.1f} min)") + print(f" Throughput: {total_proofs/duration*60:.1f} proofs/min" if duration > 0 else " Throughput: N/A") + + # Gap analysis + gaps = [] + for i in range(1, len(tx_events)): + gap = (tx_events[i].timestamp - tx_events[i-1].timestamp).total_seconds() + gaps.append(gap) + + if gaps: + large_gaps = [(g, i) for i, g in enumerate(gaps) if g > 10] + print(f"\n Inter-TX Gaps:") + print(f" Mean: {np.mean(gaps):.2f}s") + print(f" Max: {max(gaps):.1f}s") + print(f" Gaps >10s: {len(large_gaps)}") + total_gap_time = sum(g for g, _ in large_gaps) + print(f" Time lost in gaps: {total_gap_time:.1f}s ({total_gap_time/duration*100:.1f}%)") + + # Per-tree breakdown (new log format) + txs_with_tree = [t for t in tx_events if t.tree] + if txs_with_tree: + tree_counts = {} + tree_timing = {} + for t in txs_with_tree: + short = t.tree[:8] + tree_counts[short] = tree_counts.get(short, 0) + 1 + if t.e2e_ms is not None: + if short not in tree_timing: + tree_timing[short] = [] + tree_timing[short].append(t.e2e_ms) + + print(f"\n Per-Tree Breakdown:") + print(f" {'Tree':<12} {'TXs':>6} {'Avg E2E':>10}") + print(" " + "-"*30) + for tree in sorted(tree_counts.keys(), key=lambda x: -tree_counts[x]): + count = tree_counts[tree] + avg_e2e = f"{np.mean(tree_timing[tree])/1000:.1f}s" if tree in tree_timing else "N/A" + print(f" {tree}... 
{count:>6} {avg_e2e:>10}") + + if bottlenecks: + print(f"\nBottleneck Events:") + by_type = {} + for b in bottlenecks: + by_type[b.event_type] = by_type.get(b.event_type, 0) + 1 + for t, c in by_type.items(): + print(f" {t}: {c}") + + print("\n" + "="*60) + + +def main(): + parser = argparse.ArgumentParser( + description="Enhanced Forester performance analysis and visualization." + ) + parser.add_argument("logfile", type=Path, help="Path to log file") + parser.add_argument("--no-show", action="store_true", help="Don't display plots interactively") + parser.add_argument("--out", type=Path, default=None, help="Base path to save PNGs") + parser.add_argument("--summary-only", action="store_true", help="Only print summary, no plots") + + args = parser.parse_args() + + print(f"Parsing {args.logfile}...") + data = parse_log(args.logfile) + + print_summary(data) + + if args.summary_only: + return + + # Create figure with subplots (4x3 grid for detailed analysis) + fig = plt.figure(figsize=(18, 18)) + + ax1 = fig.add_subplot(4, 3, 1) + plot_latency_distribution(data["proof_completions"], ax1) + + ax2 = fig.add_subplot(4, 3, 2) + plot_latency_timeline_by_type(data["proof_completions"], ax2) + + ax3 = fig.add_subplot(4, 3, 3) + plot_time_breakdown_by_type(data["proof_completions"], ax3) + + ax4 = fig.add_subplot(4, 3, 4) + plot_proof_vs_queue_scatter(data["proof_completions"], ax4) + + ax5 = fig.add_subplot(4, 3, 5) + plot_throughput_gaps(data["tx_events"], ax5) + + ax6 = fig.add_subplot(4, 3, 6) + plot_tx_type_breakdown(data["tx_events"], ax6) + + ax7 = fig.add_subplot(4, 3, 7) + plot_pipeline_utilization(data["proof_requests"], data["proof_completions"], ax7) + + ax8 = fig.add_subplot(4, 3, 8) + plot_bottleneck_timeline(data["bottlenecks"], ax8) + + ax9 = fig.add_subplot(4, 3, 9) + plot_latency_timeline(data["proof_completions"], ax9) + + ax10 = fig.add_subplot(4, 3, 10) + plot_tx_by_tree(data["tx_events"], ax10) + + plt.tight_layout() + + if args.out: + out_path = 
args.out.with_suffix('.png') + plt.savefig(out_path, dpi=150) + print(f"\nSaved plot to: {out_path}") + + if not args.no_show: + plt.show() + + +if __name__ == "__main__": + main() diff --git a/forester/scripts/plot_proof_pipeline.py b/forester/scripts/plot_proof_pipeline.py new file mode 100644 index 0000000000..adce44e78c --- /dev/null +++ b/forester/scripts/plot_proof_pipeline.py @@ -0,0 +1,347 @@ +#!/usr/bin/env python3 +""" +Visualize the proof-to-transaction pipeline timing. + +Shows: +1. When proof jobs are submitted to prover +2. When proofs complete +3. When transactions are sent + +This helps identify: +- Queue wait time (submission → completion) +- TX batching delays (proof ready → tx sent) +- Parallel proof generation patterns +""" + +import re +import sys +from datetime import datetime +from collections import defaultdict +import matplotlib.pyplot as plt +import matplotlib.dates as mdates +from matplotlib.patches import Rectangle +import numpy as np + +def parse_timestamp(ts_str): + """Parse ISO timestamp to datetime.""" + # Handle format: 2025-12-09T14:18:41.265968Z + ts_str = ts_str.rstrip('Z') + if '.' in ts_str: + return datetime.fromisoformat(ts_str) + return datetime.fromisoformat(ts_str) + +def parse_log(filename): + """Parse log file for proof lifecycle events.""" + + submissions = [] # (timestamp, job_id, seq, type, tree) + completions = [] # (timestamp, job_id, seq, round_trip_ms, proof_ms) + txs = [] # (timestamp, type, ixs, seq_range, epoch) + + # Patterns + submit_pattern = re.compile( + r'(\d{4}-\d{2}-\d{2}T[\d:.]+)Z.*Submitted proof job seq=(\d+) type=(\w+) job_id=([\w-]+)' + ) + complete_pattern = re.compile( + r'(\d{4}-\d{2}-\d{2}T[\d:.]+)Z.*Proof completed for seq=(\d+) job_id=([\w-]+) round_trip=(\d+)ms proof=(\d+)ms' + ) + # Updated pattern to capture tree, timing info + # Format: tx sent: type= ixs= tree= root= seq=.. 
epoch= e2e=ms + tx_pattern = re.compile( + r'(\d{4}-\d{2}-\d{2}T[\d:.]+)Z.*tx sent: \w+ type=([^\s]+) ixs=(\d+)(?: tree=(\w+))?(?: root=\[[^\]]+\])? seq=(\d+)\.\.(\d+) epoch=(\d+)(?: e2e=(\d+)ms)?' + ) + + with open(filename, 'r') as f: + for line in f: + # Remove ANSI codes + line = re.sub(r'\x1b\[[0-9;]*m', '', line) + + if m := submit_pattern.search(line): + ts = parse_timestamp(m.group(1)) + submissions.append({ + 'timestamp': ts, + 'seq': int(m.group(2)), + 'type': m.group(3), + 'job_id': m.group(4), + }) + + if m := complete_pattern.search(line): + ts = parse_timestamp(m.group(1)) + completions.append({ + 'timestamp': ts, + 'seq': int(m.group(2)), + 'job_id': m.group(3), + 'round_trip_ms': int(m.group(4)), + 'proof_ms': int(m.group(5)), + }) + + if m := tx_pattern.search(line): + ts = parse_timestamp(m.group(1)) + txs.append({ + 'timestamp': ts, + 'type': m.group(2), + 'ixs': int(m.group(3)), + 'tree': m.group(4) if m.group(4) else None, + 'seq_start': int(m.group(5)), + 'seq_end': int(m.group(6)), + 'epoch': int(m.group(7)), + 'e2e_ms': int(m.group(8)) if m.group(8) else None, + }) + + return submissions, completions, txs + +def plot_pipeline(submissions, completions, txs, output_file='proof_pipeline.png'): + """Create timeline visualization.""" + + if not submissions and not completions and not txs: + print("No data to plot!") + return + + fig, axes = plt.subplots(3, 1, figsize=(16, 12), sharex=True) + + # Get time range + all_times = [] + if submissions: + all_times.extend([s['timestamp'] for s in submissions]) + if completions: + all_times.extend([c['timestamp'] for c in completions]) + if txs: + all_times.extend([t['timestamp'] for t in txs]) + + if not all_times: + print("No timestamps found!") + return + + min_time = min(all_times) + max_time = max(all_times) + + # Convert to seconds from start + def to_seconds(dt): + return (dt - min_time).total_seconds() + + # Color map for proof types + type_colors = { + 'append': '#2ecc71', # green + 'update': 
'#e74c3c', # red + 'address_append': '#3498db', # blue + } + + # Plot 1: Proof Submissions + ax1 = axes[0] + ax1.set_title('Proof Job Submissions (when sent to prover)', fontsize=12, fontweight='bold') + ax1.set_ylabel('Proof Type') + + type_y = {'append': 0, 'update': 1, 'address_append': 2} + for sub in submissions: + t = to_seconds(sub['timestamp']) + y = type_y.get(sub['type'], 0) + color = type_colors.get(sub['type'], 'gray') + ax1.scatter(t, y, c=color, alpha=0.6, s=20, marker='|') + + ax1.set_yticks([0, 1, 2]) + ax1.set_yticklabels(['append', 'update', 'address_append']) + ax1.set_ylim(-0.5, 2.5) + ax1.grid(True, alpha=0.3) + + # Plot 2: Proof Completions with round-trip time + ax2 = axes[1] + ax2.set_title('Proof Completions (color = round-trip time)', fontsize=12, fontweight='bold') + ax2.set_ylabel('Round-trip (ms)') + + if completions: + times = [to_seconds(c['timestamp']) for c in completions] + round_trips = [c['round_trip_ms'] for c in completions] + + scatter = ax2.scatter(times, round_trips, c=round_trips, cmap='RdYlGn_r', + alpha=0.7, s=30, vmin=0, vmax=max(10000, max(round_trips))) + plt.colorbar(scatter, ax=ax2, label='Round-trip (ms)') + + ax2.set_yscale('log') + ax2.grid(True, alpha=0.3) + + # Plot 3: Transaction Timeline + ax3 = axes[2] + ax3.set_title('Transactions Sent', fontsize=12, fontweight='bold') + ax3.set_ylabel('Epoch') + ax3.set_xlabel('Time (seconds from start)') + + if txs: + times = [to_seconds(t['timestamp']) for t in txs] + epochs = [t['epoch'] for t in txs] + + # Color by tx type + tx_colors = [] + for tx in txs: + if 'Append+Nullify' in tx['type']: + tx_colors.append('#9b59b6') # purple + elif 'Append' in tx['type']: + tx_colors.append('#2ecc71') # green + elif 'Nullify' in tx['type']: + tx_colors.append('#e74c3c') # red + elif 'Address' in tx['type']: + tx_colors.append('#3498db') # blue + else: + tx_colors.append('gray') + + ax3.scatter(times, epochs, c=tx_colors, alpha=0.7, s=50, marker='s') + + # Add vertical lines for 
epoch boundaries + epoch_changes = [] + prev_epoch = None + for tx in txs: + if prev_epoch is not None and tx['epoch'] != prev_epoch: + epoch_changes.append(to_seconds(tx['timestamp'])) + prev_epoch = tx['epoch'] + + for ec in epoch_changes: + ax3.axvline(x=ec, color='red', linestyle='--', alpha=0.5, linewidth=1) + + ax3.grid(True, alpha=0.3) + + # Add legend + from matplotlib.lines import Line2D + legend_elements = [ + Line2D([0], [0], marker='s', color='w', markerfacecolor='#2ecc71', markersize=10, label='Append'), + Line2D([0], [0], marker='s', color='w', markerfacecolor='#e74c3c', markersize=10, label='Nullify'), + Line2D([0], [0], marker='s', color='w', markerfacecolor='#3498db', markersize=10, label='AddressAppend'), + Line2D([0], [0], marker='s', color='w', markerfacecolor='#9b59b6', markersize=10, label='Append+Nullify'), + Line2D([0], [0], color='red', linestyle='--', alpha=0.5, label='Epoch change'), + ] + ax3.legend(handles=legend_elements, loc='upper right') + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches='tight') + print(f"Saved pipeline visualization to {output_file}") + + # Print statistics + print(f"\n{'='*60}") + print("PROOF PIPELINE STATISTICS") + print('='*60) + print(f"Total duration: {(max_time - min_time).total_seconds():.1f}s") + print(f"Proof submissions: {len(submissions)}") + print(f"Proof completions: {len(completions)}") + print(f"Transactions sent: {len(txs)}") + + if completions: + round_trips = [c['round_trip_ms'] for c in completions] + print(f"\nRound-trip times:") + print(f" Min: {min(round_trips)}ms") + print(f" Max: {max(round_trips)}ms") + print(f" Mean: {np.mean(round_trips):.0f}ms") + print(f" Median: {np.median(round_trips):.0f}ms") + + if txs: + # Calculate inter-tx gaps + tx_times = sorted([t['timestamp'] for t in txs]) + gaps = [(tx_times[i+1] - tx_times[i]).total_seconds() for i in range(len(tx_times)-1)] + if gaps: + print(f"\nInter-TX gaps:") + print(f" Max gap: {max(gaps):.1f}s") + print(f" Gaps 
> 5s: {sum(1 for g in gaps if g > 5)}") + print(f" Gaps > 10s: {sum(1 for g in gaps if g > 10)}") + + # TX timing stats (new log format) + tx_e2e_times = [t['e2e_ms'] for t in txs if t.get('e2e_ms') is not None] + if tx_e2e_times: + print(f"\nTX end-to-end latency (proof submit → tx sent):") + print(f" Min: {min(tx_e2e_times)}ms") + print(f" Max: {max(tx_e2e_times)}ms") + print(f" Mean: {np.mean(tx_e2e_times):.0f}ms") + print(f" Median: {np.median(tx_e2e_times):.0f}ms") + + # Per-tree stats + trees = set(t.get('tree') for t in txs if t.get('tree')) + if trees: + print(f"\nTXs per tree:") + for tree in sorted(trees): + count = sum(1 for t in txs if t.get('tree') == tree) + print(f" {tree[:8]}...: {count} txs") + +def plot_proof_lifecycle(submissions, completions, output_file='proof_lifecycle.png'): + """Create Gantt-style chart showing proof lifecycles.""" + + # Match submissions to completions by job_id + job_lifecycles = {} + + for sub in submissions: + job_id = sub['job_id'] + job_lifecycles[job_id] = { + 'submit_time': sub['timestamp'], + 'type': sub['type'], + 'seq': sub['seq'], + } + + for comp in completions: + job_id = comp['job_id'] + if job_id in job_lifecycles: + job_lifecycles[job_id]['complete_time'] = comp['timestamp'] + job_lifecycles[job_id]['round_trip_ms'] = comp['round_trip_ms'] + + # Filter to jobs with both submit and complete + complete_jobs = {k: v for k, v in job_lifecycles.items() + if 'complete_time' in v} + + if not complete_jobs: + print("No complete job lifecycles found!") + return + + # Sort by submit time + sorted_jobs = sorted(complete_jobs.values(), key=lambda x: x['submit_time']) + + # Limit to first 100 for readability + sorted_jobs = sorted_jobs[:100] + + fig, ax = plt.subplots(figsize=(16, 10)) + + min_time = min(j['submit_time'] for j in sorted_jobs) + + type_colors = { + 'append': '#2ecc71', + 'update': '#e74c3c', + 'address_append': '#3498db', + } + + for i, job in enumerate(sorted_jobs): + start = (job['submit_time'] - 
min_time).total_seconds() + end = (job['complete_time'] - min_time).total_seconds() + duration = end - start + + color = type_colors.get(job['type'], 'gray') + + # Draw bar from submit to complete + ax.barh(i, duration, left=start, height=0.8, + color=color, alpha=0.7, edgecolor='black', linewidth=0.5) + + ax.set_xlabel('Time (seconds from start)') + ax.set_ylabel('Proof Job (ordered by submission)') + ax.set_title('Proof Job Lifecycles (submit → complete)', fontsize=14, fontweight='bold') + + # Legend + from matplotlib.patches import Patch + legend_elements = [ + Patch(facecolor='#2ecc71', label='append'), + Patch(facecolor='#e74c3c', label='update/nullify'), + Patch(facecolor='#3498db', label='address_append'), + ] + ax.legend(handles=legend_elements, loc='upper right') + + ax.grid(True, alpha=0.3, axis='x') + + plt.tight_layout() + plt.savefig(output_file, dpi=150, bbox_inches='tight') + print(f"Saved lifecycle chart to {output_file}") + +if __name__ == '__main__': + if len(sys.argv) < 2: + print("Usage: python plot_proof_pipeline.py ") + sys.exit(1) + + logfile = sys.argv[1] + print(f"Parsing {logfile}...") + + submissions, completions, txs = parse_log(logfile) + + print(f"Found {len(submissions)} submissions, {len(completions)} completions, {len(txs)} txs") + + # Generate both visualizations + plot_pipeline(submissions, completions, txs) + plot_proof_lifecycle(submissions, completions) diff --git a/forester/src/cli.rs b/forester/src/cli.rs index 73b15825d9..f64c03322f 100644 --- a/forester/src/cli.rs +++ b/forester/src/cli.rs @@ -128,6 +128,22 @@ pub struct StartArgs { )] pub ops_cache_ttl_seconds: u64, + #[arg( + long, + env = "FORESTER_CONFIRMATION_MAX_ATTEMPTS", + default_value = "60", + help = "Maximum attempts to confirm a transaction before timing out" + )] + pub confirmation_max_attempts: u32, + + #[arg( + long, + env = "FORESTER_CONFIRMATION_POLL_INTERVAL_MS", + default_value = "500", + help = "Interval between confirmation polling attempts in 
milliseconds" + )] + pub confirmation_poll_interval_ms: u64, + #[arg(long, env = "FORESTER_CU_LIMIT", default_value = "1000000")] pub cu_limit: u32, @@ -230,18 +246,34 @@ pub struct StartArgs { #[arg( long, - env = "FORESTER_TREE_ID", - help = "Process only the specified tree (Pubkey). If specified, forester will process only this tree and ignore all others" + env = "FORESTER_QUEUE_POLLING_MODE", + default_value_t = QueuePollingMode::Indexer, + help = "Queue polling mode: indexer (poll indexer API, requires indexer_url), onchain (read queue status directly from RPC)" + )] + pub queue_polling_mode: QueuePollingMode, + + #[arg( + long = "tree-id", + env = "FORESTER_TREE_IDS", + help = "Process only the specified trees (Pubkeys). Can be specified multiple times. If specified, forester will process only these trees and ignore all others", + value_delimiter = ',' )] - pub tree_id: Option, + pub tree_ids: Vec, #[arg( long, env = "FORESTER_ENABLE_COMPRESSIBLE", - help = "Enable compressible account tracking and compression using ws_rpc_url", + help = "Enable compressible account tracking and compression using ws_rpc_url (requires --ws-rpc-url)", default_value = "false" )] pub enable_compressible: bool, + + #[arg( + long, + env = "FORESTER_LOOKUP_TABLE_ADDRESS", + help = "Address lookup table pubkey for versioned transactions. If not provided, legacy transactions will be used." + )] + pub lookup_table_address: Option, } #[derive(Parser, Clone, Debug)] @@ -320,6 +352,18 @@ pub enum ProcessorMode { All, } +/// Queue polling mode determines how the forester discovers pending queue items. 
+#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +pub enum QueuePollingMode { + /// Poll the indexer API for queue status (requires indexer_url) + #[clap(name = "indexer")] + #[default] + Indexer, + /// Read queue status directly from on-chain accounts via RPC + #[clap(name = "onchain")] + OnChain, +} + impl std::fmt::Display for ProcessorMode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -330,6 +374,15 @@ impl std::fmt::Display for ProcessorMode { } } +impl std::fmt::Display for QueuePollingMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + QueuePollingMode::Indexer => write!(f, "indexer"), + QueuePollingMode::OnChain => write!(f, "onchain"), + } + } +} + #[cfg(test)] mod tests { use clap::Parser; diff --git a/forester/src/compressible/bootstrap.rs b/forester/src/compressible/bootstrap.rs index 8d3216c50e..a2ab22df44 100644 --- a/forester/src/compressible/bootstrap.rs +++ b/forester/src/compressible/bootstrap.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use base64::{engine::general_purpose, Engine as _}; use borsh::BorshDeserialize; use light_ctoken_interface::{ state::{extensions::ExtensionStruct, CToken}, @@ -97,7 +96,7 @@ fn process_account( } }; - let data_bytes = match general_purpose::STANDARD.decode(data_str) { + let data_bytes = match base64::decode(data_str) { Ok(bytes) => bytes, Err(e) => { debug!("Failed to decode base64 for account {}: {:?}", pubkey, e); diff --git a/forester/src/compressible/subscriber.rs b/forester/src/compressible/subscriber.rs index c2d857747a..7f73f4bc68 100644 --- a/forester/src/compressible/subscriber.rs +++ b/forester/src/compressible/subscriber.rs @@ -121,18 +121,13 @@ impl AccountSubscriber { use solana_account_decoder::UiAccountData; let account_data = match &response.value.account.data { UiAccountData::Binary(data, encoding) => match encoding { - solana_account_decoder::UiAccountEncoding::Base64 => { - match 
base64::engine::Engine::decode( - &base64::engine::general_purpose::STANDARD, - data, - ) { - Ok(decoded) => decoded, - Err(e) => { - error!("Failed to decode base64 for {}: {}", pubkey, e); - return; - } + solana_account_decoder::UiAccountEncoding::Base64 => match base64::decode(data) { + Ok(decoded) => decoded, + Err(e) => { + error!("Failed to decode base64 for {}: {}", pubkey, e); + return; } - } + }, _ => { error!("Unexpected encoding for account {}", pubkey); return; diff --git a/forester/src/config.rs b/forester/src/config.rs index 73853a2251..dd7c96121c 100644 --- a/forester/src/config.rs +++ b/forester/src/config.rs @@ -8,7 +8,7 @@ use light_registry::{EpochPda, ForesterEpochPda}; use solana_sdk::{pubkey::Pubkey, signature::Keypair}; use crate::{ - cli::{ProcessorMode, StartArgs, StatusArgs}, + cli::{ProcessorMode, QueuePollingMode, StartArgs, StatusArgs}, errors::ConfigError, Result, }; @@ -28,6 +28,8 @@ pub struct ForesterConfig { pub address_tree_data: Vec, pub state_tree_data: Vec, pub compressible_config: Option, + /// Address lookup table for versioned transactions. If None, legacy transactions are used. + pub lookup_table_address: Option, } #[derive(Debug, Clone)] @@ -74,6 +76,10 @@ pub struct TransactionConfig { pub enable_priority_fees: bool, pub tx_cache_ttl_seconds: u64, pub ops_cache_ttl_seconds: u64, + /// Maximum attempts to confirm a transaction before timing out. + pub confirmation_max_attempts: u32, + /// Interval between confirmation polling attempts in milliseconds. 
+ pub confirmation_poll_interval_ms: u64, } #[derive(Debug, Clone)] @@ -85,9 +91,10 @@ pub struct GeneralConfig { pub skip_v1_address_trees: bool, pub skip_v2_state_trees: bool, pub skip_v2_address_trees: bool, - pub tree_id: Option, + pub tree_ids: Vec, pub sleep_after_processing_ms: u64, pub sleep_when_idle_ms: u64, + pub queue_polling_mode: QueuePollingMode, } impl Default for GeneralConfig { @@ -100,9 +107,10 @@ impl Default for GeneralConfig { skip_v1_address_trees: false, skip_v2_state_trees: false, skip_v2_address_trees: false, - tree_id: None, + tree_ids: vec![], sleep_after_processing_ms: 10_000, sleep_when_idle_ms: 45_000, + queue_polling_mode: QueuePollingMode::Indexer, } } } @@ -117,9 +125,10 @@ impl GeneralConfig { skip_v1_address_trees: true, skip_v2_state_trees: true, skip_v2_address_trees: false, - tree_id: None, + tree_ids: vec![], sleep_after_processing_ms: 50, sleep_when_idle_ms: 100, + queue_polling_mode: QueuePollingMode::Indexer, } } @@ -132,9 +141,10 @@ impl GeneralConfig { skip_v1_address_trees: true, skip_v2_state_trees: false, skip_v2_address_trees: true, - tree_id: None, + tree_ids: vec![], sleep_after_processing_ms: 50, sleep_when_idle_ms: 100, + queue_polling_mode: QueuePollingMode::Indexer, } } } @@ -179,6 +189,8 @@ impl Default for TransactionConfig { enable_priority_fees: false, tx_cache_ttl_seconds: 15, ops_cache_ttl_seconds: 180, + confirmation_max_attempts: 60, + confirmation_poll_interval_ms: 500, } } } @@ -195,8 +207,11 @@ impl ForesterConfig { } None => return Err(ConfigError::MissingField { field: "payer" })?, }; - let payer = Keypair::try_from(payer.as_slice()) - .map_err(|e| ConfigError::InvalidKeypair(e.to_string()))?; + let payer = + Keypair::try_from(payer.as_slice()).map_err(|e| ConfigError::InvalidArguments { + field: "payer", + invalid_values: vec![e.to_string()], + })?; let derivation: Vec = match &args.derivation { Some(derivation_str) => { @@ -214,8 +229,9 @@ impl ForesterConfig { let derivation_array: [u8; 32] = 
derivation .try_into() - .map_err(|_| ConfigError::InvalidDerivation { - reason: "must be exactly 32 bytes".to_string(), + .map_err(|_| ConfigError::InvalidArguments { + field: "derivation", + invalid_values: vec!["must be exactly 32 bytes".to_string()], })?; let derivation = Pubkey::from(derivation_array); @@ -276,6 +292,8 @@ impl ForesterConfig { enable_priority_fees: args.enable_priority_fees, tx_cache_ttl_seconds: args.tx_cache_ttl_seconds, ops_cache_ttl_seconds: args.ops_cache_ttl_seconds, + confirmation_max_attempts: args.confirmation_max_attempts, + confirmation_poll_interval_ms: args.confirmation_poll_interval_ms, }, general_config: GeneralConfig { slot_update_interval_seconds: args.slot_update_interval_seconds, @@ -285,12 +303,28 @@ impl ForesterConfig { skip_v2_state_trees: args.processor_mode == ProcessorMode::V1, skip_v1_address_trees: args.processor_mode == ProcessorMode::V2, skip_v2_address_trees: args.processor_mode == ProcessorMode::V1, - tree_id: args - .tree_id - .as_ref() - .and_then(|id| Pubkey::from_str(id).ok()), + tree_ids: { + let (valid, invalid): (Vec<_>, Vec<_>) = args + .tree_ids + .iter() + .map(|id| Pubkey::from_str(id).map_err(|_| id.clone())) + .partition(|r| r.is_ok()); + + if !invalid.is_empty() { + let invalid_values: Vec = + invalid.into_iter().map(|r| r.unwrap_err()).collect(); + return Err(ConfigError::InvalidArguments { + field: "tree_ids", + invalid_values, + } + .into()); + } + + valid.into_iter().map(|r| r.unwrap()).collect() + }, sleep_after_processing_ms: 10_000, sleep_when_idle_ms: 45_000, + queue_polling_mode: args.queue_polling_mode, }, rpc_pool_config: RpcPoolConfig { max_size: args.rpc_pool_size, @@ -301,9 +335,9 @@ impl ForesterConfig { max_retry_delay_ms: args.rpc_pool_max_retry_delay_ms, }, registry_pubkey: Pubkey::from_str(®istry_pubkey).map_err(|e| { - ConfigError::InvalidPubkey { + ConfigError::InvalidArguments { field: "registry_pubkey", - error: e.to_string(), + invalid_values: vec![e.to_string()], } })?, 
payer_keypair: payer, @@ -311,12 +345,34 @@ impl ForesterConfig { address_tree_data: vec![], state_tree_data: vec![], compressible_config: if args.enable_compressible { - args.ws_rpc_url - .clone() - .map(crate::compressible::config::CompressibleConfig::new) + match &args.ws_rpc_url { + Some(ws_url) => Some(crate::compressible::config::CompressibleConfig::new( + ws_url.clone(), + )), + None => { + return Err(ConfigError::InvalidArguments { + field: "enable_compressible", + invalid_values: vec![ + "--ws-rpc-url is required when --enable-compressible is true" + .to_string(), + ], + } + .into()) + } + } } else { None }, + lookup_table_address: args + .lookup_table_address + .as_ref() + .map(|s| { + Pubkey::from_str(s).map_err(|e| ConfigError::InvalidArguments { + field: "lookup_table_address", + invalid_values: vec![e.to_string()], + }) + }) + .transpose()?, }) } @@ -355,9 +411,10 @@ impl ForesterConfig { skip_v2_state_trees: false, skip_v1_address_trees: false, skip_v2_address_trees: false, - tree_id: None, + tree_ids: vec![], sleep_after_processing_ms: 10_000, sleep_when_idle_ms: 45_000, + queue_polling_mode: QueuePollingMode::OnChain, // Status uses on-chain reads }, rpc_pool_config: RpcPoolConfig { max_size: 10, @@ -373,6 +430,7 @@ impl ForesterConfig { address_tree_data: vec![], state_tree_data: vec![], compressible_config: None, + lookup_table_address: None, }) } } @@ -392,6 +450,7 @@ impl Clone for ForesterConfig { address_tree_data: self.address_tree_data.clone(), state_tree_data: self.state_tree_data.clone(), compressible_config: self.compressible_config.clone(), + lookup_table_address: self.lookup_table_address, } } } diff --git a/forester/src/epoch_manager.rs b/forester/src/epoch_manager.rs index e5aa9140d4..e1c42200ea 100644 --- a/forester/src/epoch_manager.rs +++ b/forester/src/epoch_manager.rs @@ -8,19 +8,23 @@ use std::{ }; use anyhow::{anyhow, Context}; -use dashmap::DashMap; +use borsh::BorshSerialize; +use dashmap::{mapref::entry::Entry, DashMap}; 
use forester_utils::{ forester_epoch::{get_epoch_phases, Epoch, ForesterSlot, TreeAccounts, TreeForesterSchedule}, rpc_pool::SolanaRpcPool, }; use futures::future::join_all; -use kameo::actor::{ActorRef, Spawn}; use light_client::{ indexer::{MerkleProof, NewAddressProofWithContext}, rpc::{LightClient, LightClientConfig, RetryConfig, Rpc, RpcError}, }; use light_compressed_account::TreeType; use light_registry::{ + account_compression_cpi::sdk::{ + create_batch_append_instruction, create_batch_nullify_instruction, + create_batch_update_address_tree_instruction, + }, protocol_config::state::{EpochState, ProtocolConfig}, sdk::{create_finalize_registration_instruction, create_report_work_instruction}, utils::{get_epoch_pda_address, get_forester_epoch_pda_from_authority}, @@ -30,6 +34,7 @@ use solana_program::{ instruction::InstructionError, native_token::LAMPORTS_PER_SOL, pubkey::Pubkey, }; use solana_sdk::{ + address_lookup_table::AddressLookupTableAccount, signature::{Keypair, Signer}, transaction::TransactionError, }; @@ -47,7 +52,6 @@ use crate::{ }, metrics::{push_metrics, queue_metric_update, update_forester_sol_balance}, pagerduty::send_pagerduty_alert, - polling::{QueueInfoPoller, QueueUpdateMessage, RegisterTree}, processor::{ tx_cache::ProcessedHashCache, v1::{ @@ -55,7 +59,12 @@ use crate::{ send_transaction::send_batched_transactions, tx_builder::EpochManagerTransactions, }, - v2::{self, process_batched_operations, BatchContext, ProverConfig}, + v2::{ + errors::V2Error, + strategy::{AddressTreeStrategy, StateTreeStrategy}, + BatchContext, BatchInstruction, ProcessingResult, ProverConfig, QueueProcessor, + SharedProofCache, + }, }, queue_helpers::QueueItemData, rollover::{ @@ -64,17 +73,97 @@ use crate::{ }, slot_tracker::{slot_duration, wait_until_slot_reached, SlotTracker}, tree_data_sync::fetch_trees, - tree_finder::TreeFinder, ForesterConfig, ForesterEpochInfo, Result, }; -/// Map of tree pubkey to (epoch, supervisor actor reference) -type 
StateSupervisorMap = Arc>)>>; +fn is_v2_error(err: &anyhow::Error, predicate: impl FnOnce(&V2Error) -> bool) -> bool { + err.downcast_ref::().is_some_and(predicate) +} + +type StateBatchProcessorMap = + Arc>>)>>; +type AddressBatchProcessorMap = + Arc>>)>>; + +/// Timing for a single circuit type (circuit inputs + proof generation) +#[derive(Copy, Clone, Debug, Default)] +pub struct CircuitMetrics { + /// Time spent building circuit inputs + pub circuit_inputs_duration: std::time::Duration, + /// Time spent generating ZK proofs (pure prover server time) + pub proof_generation_duration: std::time::Duration, + /// Total round-trip time (submit to result, includes queue wait) + pub round_trip_duration: std::time::Duration, +} + +impl CircuitMetrics { + pub fn total(&self) -> std::time::Duration { + self.circuit_inputs_duration + self.proof_generation_duration + } +} + +impl std::ops::AddAssign for CircuitMetrics { + fn add_assign(&mut self, rhs: Self) { + self.circuit_inputs_duration += rhs.circuit_inputs_duration; + self.proof_generation_duration += rhs.proof_generation_duration; + self.round_trip_duration += rhs.round_trip_duration; + } +} + +/// Timing breakdown by circuit type +#[derive(Copy, Clone, Debug, Default)] +pub struct ProcessingMetrics { + /// State append circuit (output queue processing) + pub append: CircuitMetrics, + /// State nullify circuit (input queue processing) + pub nullify: CircuitMetrics, + /// Address append circuit + pub address_append: CircuitMetrics, + /// Time spent sending transactions (overlapped with proof gen) + pub tx_sending_duration: std::time::Duration, +} + +impl ProcessingMetrics { + pub fn total(&self) -> std::time::Duration { + self.append.total() + + self.nullify.total() + + self.address_append.total() + + self.tx_sending_duration + } + + pub fn total_circuit_inputs(&self) -> std::time::Duration { + self.append.circuit_inputs_duration + + self.nullify.circuit_inputs_duration + + self.address_append.circuit_inputs_duration + 
} + + pub fn total_proof_generation(&self) -> std::time::Duration { + self.append.proof_generation_duration + + self.nullify.proof_generation_duration + + self.address_append.proof_generation_duration + } + + pub fn total_round_trip(&self) -> std::time::Duration { + self.append.round_trip_duration + + self.nullify.round_trip_duration + + self.address_append.round_trip_duration + } +} + +impl std::ops::AddAssign for ProcessingMetrics { + fn add_assign(&mut self, rhs: Self) { + self.append += rhs.append; + self.nullify += rhs.nullify; + self.address_append += rhs.address_append; + self.tx_sending_duration += rhs.tx_sending_duration; + } +} #[derive(Copy, Clone, Debug)] pub struct WorkReport { pub epoch: u64, pub processed_items: usize, + pub metrics: ProcessingMetrics, } #[derive(Debug, Clone)] @@ -107,15 +196,21 @@ pub struct EpochManager { authority: Arc, work_report_sender: mpsc::Sender, processed_items_per_epoch_count: Arc>>, + processing_metrics_per_epoch: Arc>>, trees: Arc>>, slot_tracker: Arc, processing_epochs: Arc>>, new_tree_sender: broadcast::Sender, tx_cache: Arc>, ops_cache: Arc>, - queue_poller: Option>, - state_supervisors: StateSupervisorMap, + /// Proof caches for pre-warming during idle slots + proof_caches: Arc>>, + state_processors: StateBatchProcessorMap, + address_processors: AddressBatchProcessorMap, compressible_tracker: Option>, + /// Cached zkp_batch_size per tree to filter queue updates below threshold + zkp_batch_sizes: Arc>, + address_lookup_tables: Arc>, } impl Clone for EpochManager { @@ -127,15 +222,19 @@ impl Clone for EpochManager { authority: self.authority.clone(), work_report_sender: self.work_report_sender.clone(), processed_items_per_epoch_count: self.processed_items_per_epoch_count.clone(), + processing_metrics_per_epoch: self.processing_metrics_per_epoch.clone(), trees: self.trees.clone(), slot_tracker: self.slot_tracker.clone(), processing_epochs: self.processing_epochs.clone(), new_tree_sender: self.new_tree_sender.clone(), 
tx_cache: self.tx_cache.clone(), ops_cache: self.ops_cache.clone(), - queue_poller: self.queue_poller.clone(), - state_supervisors: self.state_supervisors.clone(), + proof_caches: self.proof_caches.clone(), + state_processors: self.state_processors.clone(), + address_processors: self.address_processors.clone(), compressible_tracker: self.compressible_tracker.clone(), + zkp_batch_sizes: self.zkp_batch_sizes.clone(), + address_lookup_tables: self.address_lookup_tables.clone(), } } } @@ -153,26 +252,8 @@ impl EpochManager { tx_cache: Arc>, ops_cache: Arc>, compressible_tracker: Option>, + address_lookup_tables: Arc>, ) -> Result { - let queue_poller = if let Some(indexer_url) = &config.external_services.indexer_url { - info!( - "Spawning QueueInfoPoller actor for indexer at {}", - indexer_url - ); - - let poller = QueueInfoPoller::new( - indexer_url.clone(), - config.external_services.photon_api_key.clone(), - ); - - let actor_ref = QueueInfoPoller::spawn(poller); - info!("QueueInfoPoller actor spawn initiated"); - Some(actor_ref) - } else { - info!("indexer_url not configured, V2 trees will not have queue updates"); - None - }; - let authority = Arc::new(config.payer_keypair.insecure_clone()); Ok(Self { config, @@ -181,15 +262,19 @@ impl EpochManager { authority, work_report_sender, processed_items_per_epoch_count: Arc::new(Mutex::new(HashMap::new())), + processing_metrics_per_epoch: Arc::new(Mutex::new(HashMap::new())), trees: Arc::new(Mutex::new(trees)), slot_tracker, processing_epochs: Arc::new(DashMap::new()), new_tree_sender, tx_cache, ops_cache, - queue_poller, - state_supervisors: Arc::new(DashMap::new()), + proof_caches: Arc::new(DashMap::new()), + state_processors: Arc::new(DashMap::new()), + address_processors: Arc::new(DashMap::new()), compressible_tracker, + zkp_batch_sizes: Arc::new(DashMap::new()), + address_lookup_tables, }) } @@ -224,6 +309,22 @@ impl EpochManager { async move { self_clone.check_sol_balance_periodically().await } }); + let _guard = 
scopeguard::guard( + ( + monitor_handle, + current_previous_handle, + new_tree_handle, + balance_check_handle, + ), + |(h1, h2, h3, h4)| { + info!("Aborting EpochManager background tasks"); + h1.abort(); + h2.abort(); + h3.abort(); + h4.abort(); + }, + ); + while let Some(epoch) = rx.recv().await { debug!("Received new epoch: {}", epoch); @@ -235,11 +336,6 @@ impl EpochManager { }); } - monitor_handle.await??; - current_previous_handle.await??; - new_tree_handle.await??; - balance_check_handle.await??; - Ok(()) } @@ -272,7 +368,10 @@ impl EpochManager { match receiver.recv().await { Ok(new_tree) => { info!("Received new tree: {:?}", new_tree); - self.add_new_tree(new_tree).await?; + if let Err(e) = self.add_new_tree(new_tree).await { + error!("Failed to add new tree: {:?}", e); + // Continue processing other trees instead of crashing + } } Err(e) => match e { RecvError::Lagged(lag) => { @@ -373,73 +472,109 @@ impl EpochManager { if last_epoch.is_none_or(|last| current_epoch > last) { debug!("New epoch detected: {}", current_epoch); - // Kill state supervisors and clear caches when a new epoch is detected - let supervisor_count = self.state_supervisors.len(); - if supervisor_count > 0 { - for entry in self.state_supervisors.iter() { - let (_, actor_ref) = entry.value(); - actor_ref.kill(); - } - self.state_supervisors.clear(); - info!( - "Killed and cleared {} state supervisor actors for new epoch {}", - supervisor_count, current_epoch - ); - } let phases = get_epoch_phases(&self.protocol_config, current_epoch); if slot < phases.registration.end { debug!("Sending current epoch {} for processing", current_epoch); - tx.send(current_epoch).await?; + if let Err(e) = tx.send(current_epoch).await { + error!( + "Failed to send current epoch {} for processing: {:?}", + current_epoch, e + ); + return Ok(()); + } last_epoch = Some(current_epoch); } } - let next_epoch = current_epoch + 1; - if last_epoch.is_none_or(|last| next_epoch > last) { - let next_phases = 
get_epoch_phases(&self.protocol_config, next_epoch); + // Find the next epoch we can register for (scan forward if needed) + let mut target_epoch = current_epoch + 1; + if last_epoch.is_none_or(|last| target_epoch > last) { + // Scan forward to find an epoch whose registration is still open + // This handles the case where we missed multiple epochs + loop { + let target_phases = get_epoch_phases(&self.protocol_config, target_epoch); + + // If registration hasn't started yet, wait for it + if slot < target_phases.registration.start { + let mut rpc = match self.rpc_pool.get_connection().await { + Ok(rpc) => rpc, + Err(e) => { + warn!("Failed to get RPC connection for slot waiting: {:?}", e); + tokio::time::sleep(Duration::from_secs(1)).await; + break; + } + }; - // If the next epoch's registration phase has started, send it immediately - if slot >= next_phases.registration.start && slot < next_phases.registration.end { - debug!( - "Next epoch {} registration phase already started, sending for processing", - next_epoch - ); - tx.send(next_epoch).await?; - last_epoch = Some(next_epoch); - continue; // Check for further epochs immediately - } + const REGISTRATION_BUFFER_SLOTS: u64 = 30; + let wait_target = target_phases + .registration + .start + .saturating_sub(REGISTRATION_BUFFER_SLOTS); + let slots_to_wait = wait_target.saturating_sub(slot); - // Otherwise, wait for the next epoch's registration phase to start - let mut rpc = self.rpc_pool.get_connection().await?; - let slots_to_wait = next_phases.registration.start.saturating_sub(slot); - debug!( - "Waiting for epoch {} registration phase to start. Current slot: {}, Registration phase start slot: {}, Slots to wait: {}", - next_epoch, slot, next_phases.registration.start, slots_to_wait - ); + debug!( + "Waiting for epoch {} registration phase. 
Current slot: {}, Wait target: {} (registration starts at {}), Slots to wait: {}", + target_epoch, slot, wait_target, target_phases.registration.start, slots_to_wait + ); - if let Err(e) = wait_until_slot_reached( - &mut *rpc, - &self.slot_tracker, - next_phases.registration.start, - ) - .await - { - error!("Error waiting for next registration phase: {:?}", e); - continue; - } + if let Err(e) = + wait_until_slot_reached(&mut *rpc, &self.slot_tracker, wait_target) + .await + { + error!("Error waiting for registration phase: {:?}", e); + break; + } - debug!( - "Next epoch {} registration phase started, sending for processing", - next_epoch - ); - if let Err(e) = tx.send(next_epoch).await { - error!( - "Failed to send next epoch {} for processing: {:?}", - next_epoch, e + let current_slot = self.slot_tracker.estimated_current_slot(); + if current_slot >= target_phases.registration.end { + debug!( + "Epoch {} registration ended while waiting (current slot {} >= end {}), trying next epoch", + target_epoch, current_slot, target_phases.registration.end + ); + target_epoch += 1; + continue; + } + + debug!( + "Epoch {} registration phase ready, sending for processing (current slot: {}, registration end: {})", + target_epoch, current_slot, target_phases.registration.end + ); + if let Err(e) = tx.send(target_epoch).await { + error!( + "Failed to send epoch {} for processing: {:?}", + target_epoch, e + ); + break; + } + last_epoch = Some(target_epoch); + break; + } + + // If we're within the registration window, send it + if slot < target_phases.registration.end { + debug!( + "Epoch {} registration phase is open (slot {} < end {}), sending for processing", + target_epoch, slot, target_phases.registration.end + ); + if let Err(e) = tx.send(target_epoch).await { + error!( + "Failed to send epoch {} for processing: {:?}", + target_epoch, e + ); + break; + } + last_epoch = Some(target_epoch); + break; + } + + // Registration already ended, try next epoch + debug!( + "Epoch {} 
registration already ended (slot {} >= end {}), checking next epoch", + target_epoch, slot, target_phases.registration.end ); - continue; + target_epoch += 1; } - last_epoch = Some(next_epoch); + continue; // Re-check state after processing } else { // we've already sent the next epoch, wait a bit before checking again tokio::time::sleep(Duration::from_secs(10)).await; @@ -462,6 +597,16 @@ impl EpochManager { .fetch_add(increment_by, Ordering::Relaxed); } + async fn get_processing_metrics(&self, epoch: u64) -> ProcessingMetrics { + let metrics = self.processing_metrics_per_epoch.lock().await; + metrics.get(&epoch).copied().unwrap_or_default() + } + + async fn add_processing_metrics(&self, epoch: u64, new_metrics: ProcessingMetrics) { + let mut metrics = self.processing_metrics_per_epoch.lock().await; + *metrics.entry(epoch).or_default() += new_metrics; + } + async fn recover_registration_info(&self, epoch: u64) -> Result { debug!("Recovering registration info for epoch {}", epoch); @@ -492,7 +637,10 @@ impl EpochManager { // Process the previous epoch if still in active or later phase if slot > current_phases.registration.start { debug!("Processing previous epoch: {}", previous_epoch); - tx.send(previous_epoch).await?; + if let Err(e) = tx.send(previous_epoch).await { + error!("Failed to send previous epoch for processing: {:?}", e); + return Ok(()); + } } // Only process current epoch if we can still register or are already registered @@ -502,7 +650,10 @@ impl EpochManager { "Processing current epoch: {} (registration still open)", current_epoch ); - tx.send(current_epoch).await?; + if let Err(e) = tx.send(current_epoch).await { + error!("Failed to send current epoch for processing: {:?}", e); + return Ok(()); // Channel closed, exit gracefully + } } else { // Check if we're already registered for this epoch let forester_epoch_pda_pubkey = get_forester_epoch_pda_from_authority( @@ -510,21 +661,33 @@ impl EpochManager { current_epoch, ) .0; - let rpc = 
self.rpc_pool.get_connection().await?; - if let Ok(Some(_)) = rpc - .get_anchor_account::(&forester_epoch_pda_pubkey) - .await - { - debug!( - "Processing current epoch: {} (already registered)", - current_epoch - ); - tx.send(current_epoch).await?; - } else { - warn!( - "Skipping current epoch {} - registration ended at slot {} (current slot: {})", - current_epoch, current_phases.registration.end, slot - ); + match self.rpc_pool.get_connection().await { + Ok(rpc) => { + if let Ok(Some(_)) = rpc + .get_anchor_account::(&forester_epoch_pda_pubkey) + .await + { + debug!( + "Processing current epoch: {} (already registered)", + current_epoch + ); + if let Err(e) = tx.send(current_epoch).await { + error!("Failed to send current epoch for processing: {:?}", e); + return Ok(()); // Channel closed, exit gracefully + } + } else { + warn!( + "Skipping current epoch {} - registration ended at slot {} (current slot: {})", + current_epoch, current_phases.registration.end, slot + ); + } + } + Err(e) => { + warn!( + "Failed to get RPC connection to check registration, skipping: {:?}", + e + ); + } } } @@ -611,9 +774,18 @@ impl EpochManager { self.wait_for_report_work_phase(®istration_info).await?; } - // Report work + // Always send metrics report to channel for monitoring/testing + // This ensures metrics are captured even if we missed the report_work phase + self.send_work_report(®istration_info).await?; + + // Report work on-chain only if within the report_work phase if self.sync_slot().await? 
< phases.report_work.end { - self.report_work(®istration_info).await?; + self.report_work_onchain(®istration_info).await?; + } else { + info!( + "Skipping on-chain work report for epoch {} (report_work phase ended)", + registration_info.epoch.epoch + ); } // TODO: implement @@ -663,6 +835,16 @@ impl EpochManager { .into()); } + if slot < phases.registration.start { + let slots_to_wait = phases.registration.start.saturating_sub(slot); + info!( + "Registration for epoch {} hasn't started yet (current slot: {}, starts at: {}). Waiting {} slots...", + epoch, slot, phases.registration.start, slots_to_wait + ); + let wait_duration = Duration::from_millis(slots_to_wait * 400); + sleep(wait_duration).await; + } + for attempt in 0..max_retries { match self.register_for_epoch(epoch).await { Ok(registration_info) => return Ok(registration_info), @@ -749,6 +931,7 @@ impl EpochManager { &self.protocol_config, &self.config.payer_keypair, &self.config.derivation_pubkey, + Some(epoch), ) .await .with_context(|| { @@ -880,6 +1063,10 @@ impl EpochManager { active_phase_start_slot, waiting_secs); } + + self.prewarm_all_trees_during_wait(epoch_info, active_phase_start_slot) + .await; + wait_until_slot_reached(&mut *rpc, &self.slot_tracker, active_phase_start_slot).await?; let forester_epoch_pda_pubkey = get_forester_epoch_pda_from_authority( @@ -893,6 +1080,18 @@ impl EpochManager { if let Some(registration) = existing_registration { if registration.total_epoch_weight.is_none() { + let current_slot = rpc.get_slot().await?; + if current_slot > epoch_info.epoch.phases.active.end { + info!( + "Skipping FinalizeRegistration for epoch {}: active phase ended (current slot: {}, end: {})", + epoch_info.epoch.epoch, current_slot, epoch_info.epoch.phases.active.end + ); + return Err(anyhow::anyhow!( + "Epoch {} active phase has ended, cannot finalize registration", + epoch_info.epoch.epoch + )); + } + // TODO: we can put this ix into every tx of the first batch of the current active phase let 
ix = create_finalize_registration_instruction( &self.config.payer_keypair.pubkey(), @@ -969,12 +1168,6 @@ impl EpochManager { self.sync_slot().await?; - let queue_poller = self.queue_poller.clone(); - - let self_arc = Arc::new(self.clone()); - let epoch_info_arc = Arc::new(epoch_info.clone()); - let mut handles: Vec>> = Vec::new(); - let trees_to_process: Vec<_> = epoch_info .trees .iter() @@ -982,80 +1175,15 @@ impl EpochManager { .cloned() .collect(); - let v2_trees: Vec<_> = trees_to_process - .iter() - .filter(|tree| { - matches!( - tree.tree_accounts.tree_type, - TreeType::StateV2 | TreeType::AddressV2 - ) - }) - .collect(); - - if queue_poller.is_some() { - info!("Using QueueInfoPoller for {} V2 trees", v2_trees.len()); - } - - let mut v2_receivers: std::collections::HashMap< - Pubkey, - mpsc::Receiver, - > = std::collections::HashMap::new(); - - if !v2_trees.is_empty() { - if let Some(ref poller) = queue_poller { - let registration_futures: Vec<_> = v2_trees - .iter() - .map(|tree| { - let poller = poller.clone(); - let tree_pubkey = tree.tree_accounts.merkle_tree; - async move { - let result = poller.ask(RegisterTree { tree_pubkey }).send().await; - (tree_pubkey, result) - } - }) - .collect(); - - let results = join_all(registration_futures).await; + let self_arc = Arc::new(self.clone()); + let epoch_info_arc = Arc::new(epoch_info.clone()); - for (tree_pubkey, result) in results { - match result { - Ok(rx) => { - v2_receivers.insert(tree_pubkey, rx); - } - Err(e) => { - error!( - "Failed to register V2 tree {} with queue poller: {:?}.", - tree_pubkey, e - ); - return Err(anyhow::anyhow!( - "Failed to register V2 tree {} with queue poller: {}. Cannot process without queue updates.", - tree_pubkey, e - )); - } - } - } - } else { - error!("No queue poller available for V2 trees."); - return Err(anyhow::anyhow!( - "No queue poller available for V2 trees. Cannot process without queue updates." 
- )); - } - } + let mut handles: Vec>> = Vec::with_capacity(trees_to_process.len()); for tree in trees_to_process { - let queue_update_rx = if matches!( - tree.tree_accounts.tree_type, - TreeType::StateV2 | TreeType::AddressV2 - ) { - v2_receivers.remove(&tree.tree_accounts.merkle_tree) - } else { - None - }; - - let has_channel = queue_update_rx.is_some(); info!( - "Creating thread for tree {} (type: {:?}, event: {})", - tree.tree_accounts.merkle_tree, tree.tree_accounts.tree_type, has_channel + "Creating thread for tree {} (type: {:?})", + tree.tree_accounts.merkle_tree, tree.tree_accounts.tree_type ); let self_clone = self_arc.clone(); @@ -1063,11 +1191,10 @@ impl EpochManager { let handle = tokio::spawn(async move { self_clone - .process_queue_v2( + .process_queue( &epoch_info_clone.epoch, &epoch_info_clone.forester_epoch_pda, - tree.clone(), - queue_update_rx, + tree, ) .await }); @@ -1111,76 +1238,13 @@ impl EpochManager { mut tree_schedule: TreeForesterSchedule, ) -> Result<()> { let mut current_slot = self.slot_tracker.estimated_current_slot(); - 'outer_slot_loop: while current_slot < epoch_info.phases.active.end { - let next_slot_to_process = tree_schedule - .slots - .iter_mut() - .enumerate() - .find_map(|(idx, opt_slot)| opt_slot.as_ref().map(|s| (idx, s.clone()))); - - if let Some((slot_idx, light_slot_details)) = next_slot_to_process { - match self - .process_light_slot( - epoch_info, - epoch_pda, - &tree_schedule.tree_accounts, - &light_slot_details, - ) - .await - { - Ok(_) => { - trace!( - "Successfully processed light slot {:?}", - light_slot_details.slot - ); - } - Err(e) => { - error!( - "Error processing light slot {:?}: {:?}. 
Skipping this slot.", - light_slot_details.slot, e - ); - } - } - tree_schedule.slots[slot_idx] = None; // Mark as attempted/processed - } else { - info!( - "No further eligible slots in schedule for tree {}", - tree_schedule.tree_accounts.merkle_tree - ); - break 'outer_slot_loop; - } - - current_slot = self.slot_tracker.estimated_current_slot(); - } - - info!( - "Exiting process_queue for tree {}", - tree_schedule.tree_accounts.merkle_tree - ); - Ok(()) - } - - #[instrument( - level = "debug", - skip(self, epoch_info, epoch_pda, tree_schedule, queue_update_rx), - fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch, - tree = %tree_schedule.tree_accounts.merkle_tree) - )] - pub async fn process_queue_v2( - &self, - epoch_info: &Epoch, - epoch_pda: &ForesterEpochPda, - mut tree_schedule: TreeForesterSchedule, - mut queue_update_rx: Option>, - ) -> Result<()> { - let mut current_slot = self.slot_tracker.estimated_current_slot(); let total_slots = tree_schedule.slots.len(); let eligible_slots = tree_schedule.slots.iter().filter(|s| s.is_some()).count(); let tree_type = tree_schedule.tree_accounts.tree_type; info!( - "process_queue_v2 tree={}, total_slots={}, eligible_slots={}, current_slot={}, active_phase_end={}", + "process_queue tree={}, total_slots={}, eligible_slots={}, current_slot={}, active_phase_end={}", tree_schedule.tree_accounts.merkle_tree, total_slots, eligible_slots, @@ -1207,29 +1271,17 @@ impl EpochManager { .await } TreeType::StateV2 | TreeType::AddressV2 => { - if let Some(ref mut rx) = queue_update_rx { - let consecutive_end = tree_schedule - .get_consecutive_eligibility_end(slot_idx) - .unwrap_or(light_slot_details.end_solana_slot); - self.process_light_slot_v2( - epoch_info, - epoch_pda, - &tree_schedule.tree_accounts, - &light_slot_details, - consecutive_end, - rx, - ) - .await - } else { - error!( - "No queue update channel available for V2 tree {}.", - tree_schedule.tree_accounts.merkle_tree - ); - 
Err(anyhow::anyhow!( - "No queue update channel for V2 tree {}", - tree_schedule.tree_accounts.merkle_tree - )) - } + let consecutive_end = tree_schedule + .get_consecutive_eligibility_end(slot_idx) + .unwrap_or(light_slot_details.end_solana_slot); + self.process_light_slot_v2( + epoch_info, + epoch_pda, + &tree_schedule.tree_accounts, + &light_slot_details, + consecutive_end, + ) + .await } }; @@ -1260,7 +1312,7 @@ impl EpochManager { } info!( - "Exiting process_queue_v2 for tree {}", + "Exiting process_queue for tree {}", tree_schedule.tree_accounts.merkle_tree ); Ok(()) @@ -1333,7 +1385,6 @@ impl EpochManager { forester_slot_details, forester_slot_details.end_solana_slot, estimated_slot, - None, ) .await { @@ -1376,7 +1427,7 @@ impl EpochManager { #[instrument( level = "debug", - skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details, consecutive_eligibility_end, queue_update_rx), + skip(self, epoch_info, epoch_pda, tree_accounts, forester_slot_details, consecutive_eligibility_end), fields(tree = %tree_accounts.merkle_tree) )] async fn process_light_slot_v2( @@ -1386,7 +1437,6 @@ impl EpochManager { tree_accounts: &TreeAccounts, forester_slot_details: &ForesterSlot, consecutive_eligibility_end: u64, - queue_update_rx: &mut mpsc::Receiver, ) -> Result<()> { info!( "Processing V2 light slot {} ({}-{}, consecutive_end={})", @@ -1396,6 +1446,8 @@ impl EpochManager { consecutive_eligibility_end ); + let tree_pubkey = tree_accounts.merkle_tree; + let mut rpc = self.rpc_pool.get_connection().await?; wait_until_slot_reached( &mut *rpc, @@ -1404,12 +1456,27 @@ impl EpochManager { ) .await?; - let tree_pubkey = tree_accounts.merkle_tree; + // Try to send any cached proofs first + let cached_send_start = Instant::now(); + if let Some(items_sent) = self + .try_send_cached_proofs(epoch_info, tree_accounts, consecutive_eligibility_end) + .await? 
+ { + if items_sent > 0 { + let cached_send_duration = cached_send_start.elapsed(); + info!( + "Sent {} items from cache for tree {} in {:?}", + items_sent, tree_pubkey, cached_send_duration + ); + self.update_metrics_and_counts(epoch_info.epoch, items_sent, cached_send_duration) + .await; + } + } + let mut estimated_slot = self.slot_tracker.estimated_current_slot(); - let mut timeouts = 0u32; - const MAX_TIMEOUTS: u32 = 100; - const QUEUE_UPDATE_TIMEOUT: Duration = Duration::from_millis(150); + // Polling interval for checking queue + const POLL_INTERVAL: Duration = Duration::from_millis(200); 'inner_processing_loop: loop { if estimated_slot >= forester_slot_details.end_solana_slot { @@ -1440,75 +1507,37 @@ impl EpochManager { break 'inner_processing_loop; } - match tokio::time::timeout(QUEUE_UPDATE_TIMEOUT, queue_update_rx.recv()).await { - Ok(Some(update)) => { - timeouts = 0; - - if update.queue_size > 0 { - info!( - "V2 Queue update received for tree {}: {} items (type: {:?})", - tree_pubkey, update.queue_size, update.queue_type - ); - - let processing_start_time = Instant::now(); - match self - .dispatch_tree_processing( - epoch_info, - epoch_pda, - tree_accounts, - forester_slot_details, - consecutive_eligibility_end, - estimated_slot, - Some(&update), - ) - .await - { - Ok(count) => { - if count > 0 { - info!("V2 processed {} items", count); - self.update_metrics_and_counts( - epoch_info.epoch, - count, - processing_start_time.elapsed(), - ) - .await; - } - } - Err(e) => { - error!("V2 processing failed: {:?}", e); - } - } - } else { - trace!("V2 received empty queue update for tree {}", tree_pubkey); - } - } - Ok(None) => { - info!("Queue update channel closed for tree {}.", tree_pubkey); - break 'inner_processing_loop; - } - Err(_elapsed) => { - timeouts += 1; - - if timeouts >= MAX_TIMEOUTS { - error!( - "Queue poller has not sent updates for tree {} after {} timeouts ({} total).", - tree_pubkey, - timeouts, - timeouts as u64 * 
QUEUE_UPDATE_TIMEOUT.as_millis() as u64 - ); - return Err(anyhow::anyhow!( - "Queue poller health check failed: {} consecutive timeouts for tree {}", - timeouts, - tree_pubkey - )); + // Process directly - the processor fetches queue data from the indexer + let processing_start_time = Instant::now(); + match self + .dispatch_tree_processing( + epoch_info, + epoch_pda, + tree_accounts, + forester_slot_details, + consecutive_eligibility_end, + estimated_slot, + ) + .await + { + Ok(count) => { + if count > 0 { + info!("V2 processed {} items for tree {}", count, tree_pubkey); + self.update_metrics_and_counts( + epoch_info.epoch, + count, + processing_start_time.elapsed(), + ) + .await; } else { - trace!( - "Queue update timeout for tree {} (timeout #{}, continuing to check slot window)", - tree_pubkey, - timeouts - ); + // No items to process, wait before polling again + tokio::time::sleep(POLL_INTERVAL).await; } } + Err(e) => { + error!("V2 processing failed for tree {}: {:?}", tree_pubkey, e); + tokio::time::sleep(POLL_INTERVAL).await; + } } push_metrics(&self.config.external_services.pushgateway_url).await?; @@ -1577,7 +1606,6 @@ impl EpochManager { forester_slot_details: &ForesterSlot, consecutive_eligibility_end: u64, current_solana_slot: u64, - queue_update: Option<&QueueUpdateMessage>, ) -> Result { match tree_accounts.tree_type { TreeType::Unknown => { @@ -1599,13 +1627,13 @@ impl EpochManager { .await } TreeType::StateV2 | TreeType::AddressV2 => { - self.process_v2( - epoch_info, - tree_accounts, - queue_update, - consecutive_eligibility_end, - ) - .await + let result = self + .process_v2(epoch_info, tree_accounts, consecutive_eligibility_end) + .await?; + // Accumulate processing metrics for this epoch + self.add_processing_metrics(epoch_info.epoch, result.metrics) + .await; + Ok(result.items_processed) } } } @@ -1868,7 +1896,7 @@ impl EpochManager { epoch: epoch_info.epoch, merkle_tree: tree_accounts.merkle_tree, output_queue: tree_accounts.queue, - 
prover_config: ProverConfig { + prover_config: Arc::new(ProverConfig { append_url: self .config .external_services @@ -1898,7 +1926,7 @@ impl EpochManager { .external_services .prover_max_wait_time .unwrap_or(Duration::from_secs(600)), - }, + }), ops_cache: self.ops_cache.clone(), epoch_phases: epoch_info.phases.clone(), slot_tracker: self.slot_tracker.clone(), @@ -1906,53 +1934,115 @@ impl EpochManager { output_queue_hint, num_proof_workers: self.config.transaction_config.max_concurrent_batches, forester_eligibility_end_slot: Arc::new(AtomicU64::new(eligibility_end)), + address_lookup_tables: self.address_lookup_tables.clone(), + confirmation_max_attempts: self.config.transaction_config.confirmation_max_attempts, + confirmation_poll_interval: Duration::from_millis( + self.config.transaction_config.confirmation_poll_interval_ms, + ), } } - async fn get_or_create_state_supervisor( + async fn get_or_create_state_processor( &self, epoch_info: &Epoch, tree_accounts: &TreeAccounts, - ) -> Result>> { - use dashmap::mapref::entry::Entry; + ) -> Result>>> { + // First check if we already have a processor for this tree + // We REUSE processors across epochs to preserve cached state for optimistic processing + if let Some(entry) = self.state_processors.get(&tree_accounts.merkle_tree) { + let (stored_epoch, processor_ref) = entry.value(); + let processor_clone = processor_ref.clone(); + let old_epoch = *stored_epoch; + drop(entry); // Release read lock before any async operation + + if old_epoch != epoch_info.epoch { + // Update epoch in the map (processor is reused with its cached state) + debug!( + "Reusing StateBatchProcessor for tree {} across epoch transition ({} -> {})", + tree_accounts.merkle_tree, old_epoch, epoch_info.epoch + ); + self.state_processors.insert( + tree_accounts.merkle_tree, + (epoch_info.epoch, processor_clone.clone()), + ); + // Update the processor's epoch context and phases + processor_clone + .lock() + .await + .update_epoch(epoch_info.epoch, 
epoch_info.phases.clone()); + } + return Ok(processor_clone); + } - let entry = self.state_supervisors.entry(tree_accounts.merkle_tree); + // No existing processor - create new one + let batch_context = self.build_batch_context(epoch_info, tree_accounts, None, None, None); + let processor = Arc::new(Mutex::new( + QueueProcessor::new(batch_context, StateTreeStrategy).await?, + )); - match entry { - Entry::Occupied(mut occupied) => { - let (stored_epoch, supervisor_ref) = occupied.get(); - if *stored_epoch == epoch_info.epoch { - Ok(supervisor_ref.clone()) - } else { - info!( - "Removing stale StateSupervisor for tree {} (epoch {} -> {})", - tree_accounts.merkle_tree, *stored_epoch, epoch_info.epoch - ); - // Don't pass forester_slot - StateSupervisor is long-lived across forester slots, - // so it should use the global active phase end for safety checks - let batch_context = - self.build_batch_context(epoch_info, tree_accounts, None, None, None); - let supervisor = v2::state::StateSupervisor::spawn(batch_context); - info!( - "Created StateSupervisor actor for tree {} (epoch {})", - tree_accounts.merkle_tree, epoch_info.epoch - ); - occupied.insert((epoch_info.epoch, supervisor.clone())); - Ok(supervisor) - } + // Cache the zkp_batch_size for early filtering of queue updates + let batch_size = processor.lock().await.zkp_batch_size(); + self.zkp_batch_sizes + .insert(tree_accounts.merkle_tree, batch_size); + + // Insert the new processor (or get existing if another task beat us to it) + match self.state_processors.entry(tree_accounts.merkle_tree) { + Entry::Occupied(occupied) => { + // Another task already inserted - use theirs (they may have cached state) + Ok(occupied.get().1.clone()) } Entry::Vacant(vacant) => { - // Don't pass forester_slot - StateSupervisor is long-lived across forester slots, - // so it should use the global active phase end for safety checks - let batch_context = - self.build_batch_context(epoch_info, tree_accounts, None, None, None); - let 
supervisor = v2::state::StateSupervisor::spawn(batch_context); - info!( - "Created StateSupervisor actor for tree {} (epoch {})", - tree_accounts.merkle_tree, epoch_info.epoch + vacant.insert((epoch_info.epoch, processor.clone())); + Ok(processor) + } + } + } + + async fn get_or_create_address_processor( + &self, + epoch_info: &Epoch, + tree_accounts: &TreeAccounts, + ) -> Result>>> { + if let Some(entry) = self.address_processors.get(&tree_accounts.merkle_tree) { + let (stored_epoch, processor_ref) = entry.value(); + let processor_clone = processor_ref.clone(); + let old_epoch = *stored_epoch; + drop(entry); + + if old_epoch != epoch_info.epoch { + debug!( + "Reusing AddressBatchProcessor for tree {} across epoch transition ({} -> {})", + tree_accounts.merkle_tree, old_epoch, epoch_info.epoch + ); + self.address_processors.insert( + tree_accounts.merkle_tree, + (epoch_info.epoch, processor_clone.clone()), ); - vacant.insert((epoch_info.epoch, supervisor.clone())); - Ok(supervisor) + processor_clone + .lock() + .await + .update_epoch(epoch_info.epoch, epoch_info.phases.clone()); + } + return Ok(processor_clone); + } + + // No existing processor - create new one + let batch_context = self.build_batch_context(epoch_info, tree_accounts, None, None, None); + let processor = Arc::new(Mutex::new( + QueueProcessor::new(batch_context, AddressTreeStrategy).await?, + )); + + // Cache the zkp_batch_size for early filtering of queue updates + let batch_size = processor.lock().await.zkp_batch_size(); + self.zkp_batch_sizes + .insert(tree_accounts.merkle_tree, batch_size); + + // Insert the new processor (or get existing if another task beat us to it) + match self.address_processors.entry(tree_accounts.merkle_tree) { + Entry::Occupied(occupied) => Ok(occupied.get().1.clone()), + Entry::Vacant(vacant) => { + vacant.insert((epoch_info.epoch, processor.clone())); + Ok(processor) } } } @@ -1961,70 +2051,115 @@ impl EpochManager { &self, epoch_info: &Epoch, tree_accounts: 
&TreeAccounts, - queue_update: Option<&QueueUpdateMessage>, consecutive_eligibility_end: u64, - ) -> Result { + ) -> Result { match tree_accounts.tree_type { TreeType::StateV2 => { - if let Some(update) = queue_update { - let supervisor = self - .get_or_create_state_supervisor(epoch_info, tree_accounts) - .await?; - - supervisor - .ask(v2::state::UpdateEligibility { - end_slot: consecutive_eligibility_end, - }) - .send() - .await - .map_err(|e| { - anyhow!( - "Failed to send UpdateEligibility to StateSupervisor for tree {}: {}", - tree_accounts.merkle_tree, - e - ) - })?; + let processor = self + .get_or_create_state_processor(epoch_info, tree_accounts) + .await?; - let work = v2::state::QueueWork { - queue_type: update.queue_type, - queue_size: update.queue_size, - }; + let cache = self + .proof_caches + .entry(tree_accounts.merkle_tree) + .or_insert_with(|| Arc::new(SharedProofCache::new(tree_accounts.merkle_tree))) + .clone(); - Ok(supervisor - .ask(v2::state::ProcessQueueUpdate { queue_work: work }) - .send() - .await - .map_err(|e| { - anyhow!( - "Failed to send message to StateSupervisor for tree {}: {}", + { + let mut proc = processor.lock().await; + proc.update_eligibility(consecutive_eligibility_end); + proc.set_proof_cache(cache); + } + + let mut proc = processor.lock().await; + match proc.process().await { + Ok(res) => Ok(res), + Err(e) => { + if is_v2_error(&e, V2Error::is_constraint) { + warn!( + "State processing hit constraint error for tree {}: {}. Dropping processor to flush cache.", tree_accounts.merkle_tree, e - ) - })?) - } else { - Ok(0) + ); + drop(proc); // Release lock before removing + self.state_processors.remove(&tree_accounts.merkle_tree); + self.proof_caches.remove(&tree_accounts.merkle_tree); + Err(e) + } else if is_v2_error(&e, V2Error::is_hashchain_mismatch) { + warn!( + "State processing hit hashchain mismatch for tree {}: {}. 
Clearing cache and retrying.", + tree_accounts.merkle_tree, + e + ); + proc.clear_cache().await; + Ok(ProcessingResult::default()) + } else { + warn!( + "Failed to process state queue for tree {}: {}. Will retry next tick without dropping processor.", + tree_accounts.merkle_tree, + e + ); + Ok(ProcessingResult::default()) + } + } } } TreeType::AddressV2 => { - let input_queue_hint = queue_update.map(|u| u.queue_size); - let batch_context = self.build_batch_context( - epoch_info, - tree_accounts, - input_queue_hint, - None, - Some(consecutive_eligibility_end), - ); + let processor = self + .get_or_create_address_processor(epoch_info, tree_accounts) + .await?; - process_batched_operations(batch_context, tree_accounts.tree_type) - .await - .map_err(|e| anyhow!("Failed to process V2 operations: {}", e)) + let cache = self + .proof_caches + .entry(tree_accounts.merkle_tree) + .or_insert_with(|| Arc::new(SharedProofCache::new(tree_accounts.merkle_tree))) + .clone(); + + { + let mut proc = processor.lock().await; + proc.update_eligibility(consecutive_eligibility_end); + proc.set_proof_cache(cache); + } + + let mut proc = processor.lock().await; + match proc.process().await { + Ok(res) => Ok(res), + Err(e) => { + if is_v2_error(&e, V2Error::is_constraint) { + warn!( + "Address processing hit constraint error for tree {}: {}. Dropping processor to flush cache.", + tree_accounts.merkle_tree, + e + ); + drop(proc); + self.address_processors.remove(&tree_accounts.merkle_tree); + self.proof_caches.remove(&tree_accounts.merkle_tree); + Err(e) + } else if is_v2_error(&e, V2Error::is_hashchain_mismatch) { + warn!( + "Address processing hit hashchain mismatch for tree {}: {}. Clearing cache and retrying.", + tree_accounts.merkle_tree, + e + ); + proc.clear_cache().await; + Ok(ProcessingResult::default()) + } else { + warn!( + "Failed to process address queue for tree {}: {}. 
Will retry next tick without dropping processor.", + tree_accounts.merkle_tree, + e + ); + Ok(ProcessingResult::default()) + } + } + } } _ => { warn!( "Unsupported tree type for V2 processing: {:?}", tree_accounts.tree_type ); - Ok(0) + Ok(ProcessingResult::default()) } } } @@ -2047,6 +2182,342 @@ impl EpochManager { } } + async fn prewarm_all_trees_during_wait( + &self, + epoch_info: &ForesterEpochInfo, + deadline_slot: u64, + ) { + let current_slot = self.slot_tracker.estimated_current_slot(); + let slots_until_active = deadline_slot.saturating_sub(current_slot); + + let trees = self.trees.lock().await; + let total_v2_state = trees + .iter() + .filter(|t| matches!(t.tree_type, TreeType::StateV2)) + .count(); + let v2_state_trees: Vec<_> = trees + .iter() + .filter(|t| { + matches!(t.tree_type, TreeType::StateV2) + && !should_skip_tree(&self.config, &t.tree_type) + }) + .cloned() + .collect(); + let skipped_count = total_v2_state - v2_state_trees.len(); + drop(trees); + + if v2_state_trees.is_empty() { + if skipped_count > 0 { + info!( + "No trees to pre-warm: {} StateV2 trees skipped by config", + skipped_count + ); + } + return; + } + + if slots_until_active < 15 { + info!( + "Skipping pre-warming: only {} slots until active phase, not enough time", + slots_until_active + ); + return; + } + + let prewarm_futures: Vec<_> = v2_state_trees + .iter() + .map(|tree_accounts| { + let tree_pubkey = tree_accounts.merkle_tree; + let epoch_info = epoch_info.clone(); + let tree_accounts = *tree_accounts; + let self_clone = self.clone(); + + async move { + let cache = self_clone + .proof_caches + .entry(tree_pubkey) + .or_insert_with(|| Arc::new(SharedProofCache::new(tree_pubkey))) + .clone(); + + let cache_len = cache.len().await; + if cache_len > 0 && !cache.is_warming().await { + let mut rpc = match self_clone.rpc_pool.get_connection().await { + Ok(r) => r, + Err(e) => { + warn!("Failed to get RPC for cache validation: {:?}", e); + return; + } + }; + if let 
Ok(current_root) = + self_clone.fetch_current_root(&mut *rpc, &tree_accounts).await + { + info!( + "Tree {} has {} cached proofs from previous epoch (root: {:?}), skipping pre-warm", + tree_pubkey, cache_len, &current_root[..4] + ); + return; + } + } + + let processor = match self_clone + .get_or_create_state_processor(&epoch_info.epoch, &tree_accounts) + .await + { + Ok(p) => p, + Err(e) => { + warn!( + "Failed to create processor for pre-warming tree {}: {:?}", + tree_pubkey, e + ); + return; + } + }; + + const PREWARM_MAX_BATCHES: usize = 4; + let mut p = processor.lock().await; + match p + .prewarm_from_indexer( + cache.clone(), + light_compressed_account::QueueType::OutputStateV2, + PREWARM_MAX_BATCHES, + ) + .await + { + Ok(result) => { + if result.items_processed > 0 { + info!( + "Pre-warmed {} items for tree {} during wait (metrics: {:?})", + result.items_processed, tree_pubkey, result.metrics + ); + self_clone + .add_processing_metrics(epoch_info.epoch.epoch, result.metrics) + .await; + } + } + Err(e) => { + debug!( + "Pre-warming from indexer failed for tree {}: {:?}", + tree_pubkey, e + ); + cache.clear().await; + } + } + } + }) + .collect(); + + let timeout_slots = slots_until_active.saturating_sub(5); + let timeout_duration = Duration::from_millis((timeout_slots * 400).min(30_000)); + + info!( + "Starting pre-warming for {} trees ({} skipped by config) with {}ms timeout", + v2_state_trees.len(), + skipped_count, + timeout_duration.as_millis() + ); + + match tokio::time::timeout(timeout_duration, futures::future::join_all(prewarm_futures)) + .await + { + Ok(_) => { + info!("Completed pre-warming for all trees"); + } + Err(_) => { + info!("Pre-warming timed out after {:?}", timeout_duration); + } + } + } + + async fn try_send_cached_proofs( + &self, + epoch_info: &Epoch, + tree_accounts: &TreeAccounts, + consecutive_eligibility_end: u64, + ) -> Result<Option<usize>> { + let tree_pubkey = tree_accounts.merkle_tree; + + // Check eligibility window before attempting to send 
cached proofs + let current_slot = self.slot_tracker.estimated_current_slot(); + if current_slot >= consecutive_eligibility_end { + debug!( + "Skipping cached proof send for tree {}: past eligibility window (slot {} >= {})", + tree_pubkey, current_slot, consecutive_eligibility_end + ); + return Ok(None); + } + + let cache = match self.proof_caches.get(&tree_pubkey) { + Some(c) => c.clone(), + None => return Ok(None), + }; + + if cache.is_warming().await { + debug!("Cache still warming for tree {}, skipping", tree_pubkey); + return Ok(None); + } + + let mut rpc = self.rpc_pool.get_connection().await?; + let current_root = match self.fetch_current_root(&mut *rpc, tree_accounts).await { + Ok(root) => root, + Err(e) => { + warn!( + "Failed to fetch current root for tree {}: {:?}", + tree_pubkey, e + ); + return Ok(None); + } + }; + + let cached_proofs = match cache.take_if_valid(&current_root).await { + Some(proofs) => proofs, + None => { + debug!( + "No valid cached proofs for tree {} (root: {:?})", + tree_pubkey, + &current_root[..4] + ); + return Ok(None); + } + }; + + if cached_proofs.is_empty() { + return Ok(Some(0)); + } + + info!( + "Sending {} cached proofs for tree {} (root: {:?})", + cached_proofs.len(), + tree_pubkey, + &current_root[..4] + ); + + let items_sent = self + .send_cached_proofs_as_transactions(epoch_info, tree_accounts, cached_proofs) + .await?; + + Ok(Some(items_sent)) + } + + async fn fetch_current_root( + &self, + rpc: &mut impl Rpc, + tree_accounts: &TreeAccounts, + ) -> Result<[u8; 32]> { + use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; + + let mut account = rpc + .get_account(tree_accounts.merkle_tree) + .await? 
+ .ok_or_else(|| anyhow!("Tree account not found: {}", tree_accounts.merkle_tree))?; + + let tree = match tree_accounts.tree_type { + TreeType::StateV2 => BatchedMerkleTreeAccount::state_from_bytes( + &mut account.data, + &tree_accounts.merkle_tree.into(), + )?, + TreeType::AddressV2 => BatchedMerkleTreeAccount::address_from_bytes( + &mut account.data, + &tree_accounts.merkle_tree.into(), + )?, + _ => return Err(anyhow!("Unsupported tree type for root fetch")), + }; + + let root = tree.root_history.last().copied().unwrap_or([0u8; 32]); + Ok(root) + } + + async fn send_cached_proofs_as_transactions( + &self, + epoch_info: &Epoch, + tree_accounts: &TreeAccounts, + cached_proofs: Vec, + ) -> Result { + let mut total_items = 0; + let authority = self.config.payer_keypair.pubkey(); + let derivation = self.config.derivation_pubkey; + + const PROOFS_PER_TX: usize = 4; + for chunk in cached_proofs.chunks(PROOFS_PER_TX) { + let mut instructions = Vec::new(); + let mut chunk_items = 0; + + for proof in chunk { + match &proof.instruction { + BatchInstruction::Append(data) => { + for d in data { + let serialized = d + .try_to_vec() + .with_context(|| "Failed to serialize batch append payload")?; + instructions.push(create_batch_append_instruction( + authority, + derivation, + tree_accounts.merkle_tree, + tree_accounts.queue, + epoch_info.epoch, + serialized, + )); + } + } + BatchInstruction::Nullify(data) => { + for d in data { + let serialized = d + .try_to_vec() + .with_context(|| "Failed to serialize batch nullify payload")?; + instructions.push(create_batch_nullify_instruction( + authority, + derivation, + tree_accounts.merkle_tree, + epoch_info.epoch, + serialized, + )); + } + } + BatchInstruction::AddressAppend(data) => { + for d in data { + let serialized = d.try_to_vec().with_context(|| { + "Failed to serialize batch address append payload" + })?; + instructions.push(create_batch_update_address_tree_instruction( + authority, + derivation, + tree_accounts.merkle_tree, + 
epoch_info.epoch, + serialized, + )); + } + } + } + chunk_items += proof.items; + } + + if !instructions.is_empty() { + let mut rpc = self.rpc_pool.get_connection().await?; + match rpc + .create_and_send_transaction( + &instructions, + &authority, + &[&self.config.payer_keypair], + ) + .await + { + Ok(sig) => { + info!( + "Sent cached proofs tx: {} ({} instructions)", + sig, + instructions.len() + ); + total_items += chunk_items; + } + Err(e) => { + warn!("Failed to send cached proofs tx: {:?}", e); + } + } + } + } + + Ok(total_items) + } + async fn rollover_if_needed(&self, tree_account: &TreeAccounts) -> Result<()> { let mut rpc = self.rpc_pool.get_connection().await?; if is_tree_ready_for_rollover(&mut *rpc, tree_account.merkle_tree, tree_account.tree_type) @@ -2084,8 +2555,33 @@ impl EpochManager { #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch ))] - async fn report_work(&self, epoch_info: &ForesterEpochInfo) -> Result<()> { - info!("Reporting work"); + async fn send_work_report(&self, epoch_info: &ForesterEpochInfo) -> Result<()> { + let report = WorkReport { + epoch: epoch_info.epoch.epoch, + processed_items: self.get_processed_items_count(epoch_info.epoch.epoch).await, + metrics: self.get_processing_metrics(epoch_info.epoch.epoch).await, + }; + + info!( + "Sending work report: epoch={} items={} metrics={:?}", + report.epoch, report.processed_items, report.metrics + ); + + self.work_report_sender + .send(report) + .await + .map_err(|e| ChannelError::WorkReportSend { + epoch: report.epoch, + error: e.to_string(), + })?; + + Ok(()) + } + + #[instrument(level = "debug", skip(self, epoch_info), fields(forester = %self.config.payer_keypair.pubkey(), epoch = epoch_info.epoch.epoch + ))] + async fn report_work_onchain(&self, epoch_info: &ForesterEpochInfo) -> Result<()> { + info!("Reporting work on-chain"); let mut rpc = LightClient::new(LightClientConfig { url: 
self.config.external_services.rpc_url.to_string(), photon_url: self.config.external_services.indexer_url.clone(), @@ -2129,7 +2625,7 @@ impl EpochManager { .await { Ok(_) => { - info!("Work reported"); + info!("Work reported on-chain"); } Err(e) => { if e.to_string().contains("already been processed") { @@ -2155,19 +2651,6 @@ impl EpochManager { } } - let report = WorkReport { - epoch: epoch_info.epoch.epoch, - processed_items: self.get_processed_items_count(epoch_info.epoch.epoch).await, - }; - - self.work_report_sender - .send(report) - .await - .map_err(|e| ChannelError::WorkReportSend { - epoch: report.epoch, - error: e.to_string(), - })?; - Ok(()) } @@ -2230,11 +2713,6 @@ impl EpochManager { } Ok(()) } - - #[allow(dead_code)] - async fn claim(&self, _forester_epoch_info: ForesterEpochInfo) { - todo!() - } } fn calculate_remaining_time_or_default( @@ -2274,7 +2752,7 @@ pub async fn run_service( config: Arc, protocol_config: Arc, rpc_pool: Arc>, - shutdown: oneshot::Receiver<()>, + mut shutdown: oneshot::Receiver<()>, work_report_sender: mpsc::Sender, slot_tracker: Arc, tx_cache: Arc>, @@ -2304,43 +2782,128 @@ pub async fn run_service( let start_time = Instant::now(); let trees = { - let rpc = rpc_pool.get_connection().await?; - let mut fetched_trees = fetch_trees(&*rpc).await?; - if let Some(tree_id) = config.general_config.tree_id { - fetched_trees.retain(|tree| tree.merkle_tree == tree_id); - if fetched_trees.is_empty() { - error!("Specified tree {} not found", tree_id); - return Err(anyhow::anyhow!("Specified tree {} not found", tree_id)); + let max_attempts = 10; + let mut attempts = 0; + let mut delay = Duration::from_secs(2); + + loop { + tokio::select! { + biased; + _ = &mut shutdown => { + info!("Received shutdown signal during tree fetch. Stopping."); + return Ok(()); + } + result = rpc_pool.get_connection() => { + match result { + Ok(rpc) => { + tokio::select! { + biased; + _ = &mut shutdown => { + info!("Received shutdown signal during tree fetch. 
Stopping."); + return Ok(()); + } + fetch_result = fetch_trees(&*rpc) => { + match fetch_result { + Ok(mut fetched_trees) => { + if !config.general_config.tree_ids.is_empty() { + let tree_ids = &config.general_config.tree_ids; + fetched_trees.retain(|tree| tree_ids.contains(&tree.merkle_tree)); + if fetched_trees.is_empty() { + error!("None of the specified trees found: {:?}", tree_ids); + return Err(anyhow::anyhow!( + "None of the specified trees found: {:?}", + tree_ids + )); + } + info!("Processing only trees: {:?}", tree_ids); + } + break fetched_trees; + } + Err(e) => { + attempts += 1; + if attempts >= max_attempts { + return Err(anyhow::anyhow!( + "Failed to fetch trees after {} attempts: {:?}", + max_attempts, + e + )); + } + warn!( + "Failed to fetch trees (attempt {}/{}), retrying in {:?}: {:?}", + attempts, max_attempts, delay, e + ); + } + } + } + } + } + Err(e) => { + attempts += 1; + if attempts >= max_attempts { + return Err(anyhow::anyhow!( + "Failed to get RPC connection for trees after {} attempts: {:?}", + max_attempts, + e + )); + } + warn!( + "Failed to get RPC connection (attempt {}/{}), retrying in {:?}: {:?}", + attempts, max_attempts, delay, e + ); + } + } + } + } + + tokio::select! { + biased; + _ = &mut shutdown => { + info!("Received shutdown signal during retry wait. 
Stopping."); + return Ok(()); + } + _ = sleep(delay) => { + delay = std::cmp::min(delay * 2, Duration::from_secs(30)); + } } - info!("Processing only tree: {}", tree_id); } - fetched_trees }; trace!("Fetched initial trees: {:?}", trees); let (new_tree_sender, _) = broadcast::channel(100); - // Only run tree finder if not filtering by specific tree - let _tree_finder_handle = if config.general_config.tree_id.is_none() { - let mut tree_finder = TreeFinder::new( - rpc_pool.clone(), - trees.clone(), - new_tree_sender.clone(), - Duration::from_secs(config.general_config.tree_discovery_interval_seconds), - ); - - Some(tokio::spawn(async move { - if let Err(e) = tree_finder.run().await { - error!("Tree finder error: {:?}", e); - } - })) - } else { - info!("Tree discovery disabled when processing single tree"); - None - }; + if !config.general_config.tree_ids.is_empty() { + info!("Processing specific trees, tree discovery will be limited"); + } while retry_count < config.retry_config.max_retries { debug!("Creating EpochManager (attempt {})", retry_count + 1); + + let address_lookup_tables = { + if let Some(lut_address) = config.lookup_table_address { + let rpc = rpc_pool.get_connection().await?; + match load_lookup_table_async(&*rpc, lut_address).await { + Ok(lut) => { + info!( + "Loaded lookup table {} with {} addresses", + lut_address, + lut.addresses.len() + ); + Arc::new(vec![lut]) + } + Err(e) => { + debug!( + "Lookup table {} not available: {}. Using legacy transactions.", + lut_address, e + ); + Arc::new(Vec::new()) + } + } + } else { + debug!("No lookup table address configured. Using legacy transactions."); + Arc::new(Vec::new()) + } + }; + match EpochManager::new( config.clone(), protocol_config.clone(), @@ -2352,6 +2915,7 @@ pub async fn run_service( tx_cache.clone(), ops_cache.clone(), compressible_tracker.clone(), + address_lookup_tables, ) .await { @@ -2362,13 +2926,15 @@ pub async fn run_service( retry_count + 1 ); - return tokio::select! 
{ + let result = tokio::select! { result = epoch_manager.run() => result, _ = shutdown => { info!("Received shutdown signal. Stopping the service."); Ok(()) } }; + + return result; } Err(e) => { warn!( @@ -2405,6 +2971,29 @@ pub async fn run_service( .await } +/// Async version of load_lookup_table that works with the Rpc trait +async fn load_lookup_table_async( + rpc: &R, + lookup_table_address: Pubkey, +) -> anyhow::Result { + use light_client::rpc::lut::AddressLookupTable; + + let account = rpc + .get_account(lookup_table_address) + .await? + .ok_or_else(|| { + anyhow::anyhow!("Lookup table account not found: {}", lookup_table_address) + })?; + + let address_lookup_table = AddressLookupTable::deserialize(&account.data) + .map_err(|e| anyhow::anyhow!("Failed to deserialize AddressLookupTable: {:?}", e))?; + + Ok(AddressLookupTableAccount { + key: lookup_table_address, + addresses: address_lookup_table.addresses.to_vec(), + }) +} + #[cfg(test)] mod tests { use light_client::rpc::RetryConfig; @@ -2454,9 +3043,10 @@ mod tests { skip_v1_address_trees: skip_v1_address, skip_v2_state_trees: skip_v2_state, skip_v2_address_trees: skip_v2_address, - tree_id: None, + tree_ids: vec![], sleep_after_processing_ms: 50, sleep_when_idle_ms: 100, + queue_polling_mode: crate::cli::QueuePollingMode::Indexer, }, rpc_pool_config: Default::default(), registry_pubkey: Pubkey::default(), @@ -2465,6 +3055,7 @@ mod tests { address_tree_data: vec![], state_tree_data: vec![], compressible_config: None, + lookup_table_address: None, } } @@ -2596,9 +3187,31 @@ mod tests { let report = WorkReport { epoch: 42, processed_items: 100, + metrics: ProcessingMetrics { + append: CircuitMetrics { + circuit_inputs_duration: std::time::Duration::from_secs(1), + proof_generation_duration: std::time::Duration::from_secs(3), + round_trip_duration: std::time::Duration::from_secs(10), + }, + nullify: CircuitMetrics { + circuit_inputs_duration: std::time::Duration::from_secs(1), + proof_generation_duration: 
std::time::Duration::from_secs(2), + round_trip_duration: std::time::Duration::from_secs(8), + }, + address_append: CircuitMetrics { + circuit_inputs_duration: std::time::Duration::from_secs(1), + proof_generation_duration: std::time::Duration::from_secs(2), + round_trip_duration: std::time::Duration::from_secs(9), + }, + tx_sending_duration: std::time::Duration::ZERO, + }, }; assert_eq!(report.epoch, 42); assert_eq!(report.processed_items, 100); + assert_eq!(report.metrics.total().as_secs(), 10); + assert_eq!(report.metrics.total_circuit_inputs().as_secs(), 3); + assert_eq!(report.metrics.total_proof_generation().as_secs(), 7); + assert_eq!(report.metrics.total_round_trip().as_secs(), 27); } } diff --git a/forester/src/errors.rs b/forester/src/errors.rs index 79c6cfcd45..6d652a270c 100644 --- a/forester/src/errors.rs +++ b/forester/src/errors.rs @@ -102,17 +102,14 @@ pub enum ConfigError { #[error("Missing required field: {field}")] MissingField { field: &'static str }, - #[error("Invalid keypair data: {0}")] - InvalidKeypair(String), - - #[error("Invalid pubkey: {field} - {error}")] - InvalidPubkey { field: &'static str, error: String }, - - #[error("Invalid derivation: {reason}")] - InvalidDerivation { reason: String }, - - #[error("JSON parsing error: {field} - {error}")] + #[error("JSON parsing error for {field}: {error}")] JsonParse { field: &'static str, error: String }, + + #[error("Invalid {field}: {}", .invalid_values.join(", "))] + InvalidArguments { + field: &'static str, + invalid_values: Vec, + }, } #[derive(Error, Debug)] diff --git a/forester/src/forester_status.rs b/forester/src/forester_status.rs index affc04f4e4..9c6b4c74eb 100644 --- a/forester/src/forester_status.rs +++ b/forester/src/forester_status.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use anchor_lang::{AccountDeserialize, Discriminator}; +use anyhow::Context; use forester_utils::forester_epoch::{get_epoch_phases, TreeAccounts}; use itertools::Itertools; use light_client::rpc::{LightClient, 
LightClientConfig, Rpc}; @@ -28,45 +29,47 @@ pub async fn fetch_forester_status(args: &StatusArgs) -> crate::Result<()> { ); let registry_accounts = client .get_program_accounts(&light_registry::ID) - .expect("Failed to fetch accounts for registry program."); + .context("Failed to fetch accounts for registry program")?; let mut forester_epoch_pdas = vec![]; let mut epoch_pdas = vec![]; let mut protocol_config_pdas = vec![]; for (_, account) in registry_accounts { - match account.data()[0..8].try_into()? { - ForesterEpochPda::DISCRIMINATOR => { - let forester_epoch_pda = - ForesterEpochPda::try_deserialize_unchecked(&mut account.data()) - .expect("Failed to deserialize ForesterEpochPda"); - forester_epoch_pdas.push(forester_epoch_pda); - } - EpochPda::DISCRIMINATOR => { - let epoch_pda = EpochPda::try_deserialize_unchecked(&mut account.data()) - .expect("Failed to deserialize EpochPda"); - epoch_pdas.push(epoch_pda); - } - ProtocolConfigPda::DISCRIMINATOR => { - let protocol_config_pda = - ProtocolConfigPda::try_deserialize_unchecked(&mut account.data()) - .expect("Failed to deserialize ProtocolConfigPda"); - protocol_config_pdas.push(protocol_config_pda); - } - _ => (), + let discriminator_bytes = match account.data().get(0..8) { + Some(d) => d, + None => continue, + }; + + if discriminator_bytes == ForesterEpochPda::DISCRIMINATOR { + let mut data: &[u8] = account.data(); + let forester_epoch_pda = ForesterEpochPda::try_deserialize_unchecked(&mut data) + .context("Failed to deserialize ForesterEpochPda")?; + forester_epoch_pdas.push(forester_epoch_pda); + } else if discriminator_bytes == EpochPda::DISCRIMINATOR { + let mut data: &[u8] = account.data(); + let epoch_pda = EpochPda::try_deserialize_unchecked(&mut data) + .context("Failed to deserialize EpochPda")?; + epoch_pdas.push(epoch_pda); + } else if discriminator_bytes == ProtocolConfigPda::DISCRIMINATOR { + let mut data: &[u8] = account.data(); + let protocol_config_pda = 
ProtocolConfigPda::try_deserialize_unchecked(&mut data) + .context("Failed to deserialize ProtocolConfigPda")?; + protocol_config_pdas.push(protocol_config_pda); } } forester_epoch_pdas.sort_by(|a, b| a.epoch.cmp(&b.epoch)); epoch_pdas.sort_by(|a, b| a.epoch.cmp(&b.epoch)); - let slot = client.get_slot().expect("Failed to fetch slot."); + let slot = client.get_slot().context("Failed to fetch slot")?; + + let protocol_config_pda = protocol_config_pdas + .first() + .cloned() + .context("No ProtocolConfigPda found in registry program accounts")?; println!("Current Solana Slot: {}", slot); - let current_active_epoch = protocol_config_pdas[0] - .config - .get_current_active_epoch(slot)?; - let current_registration_epoch = protocol_config_pdas[0] - .config - .get_latest_register_epoch(slot)?; + let current_active_epoch = protocol_config_pda.config.get_current_active_epoch(slot)?; + let current_registration_epoch = protocol_config_pda.config.get_latest_register_epoch(slot)?; println!("Current active epoch: {:?}", current_active_epoch); println!( @@ -104,26 +107,26 @@ pub async fn fetch_forester_status(args: &StatusArgs) -> crate::Result<()> { ); println!( "current active epoch progress {:?} / {}", - protocol_config_pdas[0] + protocol_config_pda .config .get_current_active_epoch_progress(slot), - protocol_config_pdas[0].config.active_phase_length + protocol_config_pda.config.active_phase_length ); println!( "current active epoch progress {:.2?}%", - protocol_config_pdas[0] + protocol_config_pda .config .get_current_active_epoch_progress(slot) as f64 - / protocol_config_pdas[0].config.active_phase_length as f64 + / protocol_config_pda.config.active_phase_length as f64 * 100f64 ); println!("Hours until next epoch : {:?} hours", { // slotduration is 460ms and 1000ms is 1 second and 3600 seconds is 1 hour - protocol_config_pdas[0] + protocol_config_pda .config .active_phase_length .saturating_sub( - protocol_config_pdas[0] + protocol_config_pda .config 
.get_current_active_epoch_progress(slot), ) @@ -131,11 +134,11 @@ pub async fn fetch_forester_status(args: &StatusArgs) -> crate::Result<()> { / 1000 / 3600 }); - let slots_until_next_registration = protocol_config_pdas[0] + let slots_until_next_registration = protocol_config_pda .config .registration_phase_length .saturating_sub( - protocol_config_pdas[0] + protocol_config_pda .config .get_current_active_epoch_progress(slot), ); @@ -160,7 +163,7 @@ pub async fn fetch_forester_status(args: &StatusArgs) -> crate::Result<()> { } } if args.protocol_config { - println!("protocol config: {:?}", protocol_config_pdas[0]); + println!("protocol config: {:?}", protocol_config_pda); } let config = Arc::new(ForesterConfig::new_for_status(args)?); @@ -226,8 +229,6 @@ pub async fn fetch_forester_status(args: &StatusArgs) -> crate::Result<()> { .iter() .find(|pda| pda.epoch == current_active_epoch); - let protocol_config = protocol_config_pdas[0].clone(); - print_tree_schedule_by_forester( slot, current_active_epoch, @@ -235,7 +236,7 @@ pub async fn fetch_forester_status(args: &StatusArgs) -> crate::Result<()> { tree.merkle_tree, tree.queue, current_epoch_pda_entry, - &protocol_config, + &protocol_config_pda, ); } } @@ -251,8 +252,6 @@ pub async fn fetch_forester_status(args: &StatusArgs) -> crate::Result<()> { .iter() .find(|pda| pda.epoch == current_active_epoch); - let protocol_config = protocol_config_pdas[0].clone(); - // Filter out rolled-over trees let active_trees: Vec = trees.iter().filter(|t| !t.is_rolledover).cloned().collect(); @@ -264,7 +263,7 @@ pub async fn fetch_forester_status(args: &StatusArgs) -> crate::Result<()> { active_epoch_foresters, &active_trees, current_epoch_pda_entry, - &protocol_config, + &protocol_config_pda, ); } else { println!("No active foresters found for the current epoch."); @@ -286,17 +285,20 @@ fn print_current_forester_assignments( if let Some(_current_epoch_pda) = current_epoch_pda_entry { if active_epoch_foresters.is_empty() { println!( 
- "ERROR: No foresters registered for active epoch {}", + "error: no foresters registered for active epoch {}", current_active_epoch ); return; } - let total_epoch_weight = match active_epoch_foresters[0].total_epoch_weight { + let total_epoch_weight = match active_epoch_foresters + .first() + .and_then(|pda| pda.total_epoch_weight) + { Some(w) if w > 0 => w, _ => { println!( - "ERROR: Registration not finalized (total_epoch_weight is None or 0) for epoch {}.", + "error: registration not finalized (total_epoch_weight is none or 0) for epoch {}", current_active_epoch ); return; @@ -314,7 +316,7 @@ fn print_current_forester_assignments( } if protocol_config.config.slot_length == 0 { - println!("ERROR: ProtocolConfig slot_length is zero. Cannot calculate light slots."); + println!("error: protocol config slot_length is zero; cannot calculate light slots"); return; } @@ -350,7 +352,7 @@ fn print_current_forester_assignments( Ok(idx) => idx, Err(e) => { println!( - "{:12}\t\t{}\tERROR: {:?}", + "{:12}\t\t{}\terror: {:?}", tree.tree_type, tree.merkle_tree, e ); continue; @@ -375,7 +377,7 @@ fn print_current_forester_assignments( } } else { println!( - "ERROR: Could not find EpochPda for active epoch {}. 
Cannot determine forester assignments.", + "error: could not find EpochPda for active epoch {}; cannot determine forester assignments", current_active_epoch ); } @@ -393,15 +395,18 @@ fn print_tree_schedule_by_forester( if let Some(_current_epoch_pda) = current_epoch_pda_entry { if active_epoch_foresters.is_empty() { println!( - "ERROR: No foresters registered for tree {} in active epoch {}", + "error: no foresters registered for tree {} in active epoch {}", tree, current_active_epoch ); } else { - let total_epoch_weight = match active_epoch_foresters[0].total_epoch_weight { + let total_epoch_weight = match active_epoch_foresters + .first() + .and_then(|pda| pda.total_epoch_weight) + { Some(w) if w > 0 => w, _ => { println!( - "ERROR: Registration not finalized (total_epoch_weight is None or 0) for epoch {}. Cannot check assignments.", + "error: registration not finalized (total_epoch_weight is none or 0) for epoch {}; cannot check assignments", current_active_epoch ); 0 @@ -443,7 +448,7 @@ fn print_tree_schedule_by_forester( epoch_phases.active.length() / protocol_config.config.slot_length } else { println!( - "ERROR: ProtocolConfig slot_length is zero. Cannot calculate light slots." 
+ "error: protocol config slot_length is zero; cannot calculate light slots" ); 0 }; @@ -469,7 +474,7 @@ fn print_tree_schedule_by_forester( Ok(idx) => idx, Err(e) => { println!( - "ERROR calculating eligible index for light slot {}: {:?}", + "error calculating eligible index for light slot {}: {:?}", i, e ); all_slots_checked = false; @@ -505,7 +510,13 @@ fn print_tree_schedule_by_forester( let current_light_slot_index = if slot >= epoch_phases.active.start && slot < epoch_phases.active.end { - match active_epoch_foresters[0].get_current_light_slot(slot) { + match active_epoch_foresters + .first() + .context("No foresters registered for active epoch") + .and_then(|pda| { + pda.get_current_light_slot(slot) + .context("get_current_light_slot failed") + }) { Ok(ls) => ls, Err(e) => { println!("WARN: Could not calculate current light slot from PDA (using approximation): {:?}", e); @@ -582,7 +593,7 @@ fn print_tree_schedule_by_forester( } } else { println!( - "Check FAILED: Tree {} is missing forester assignment starting at least at light slot index {} in epoch {}.", + "check failed: tree {} is missing forester assignment starting at least at light slot index {} in epoch {}", tree, first_missing_slot, current_active_epoch ); } @@ -590,7 +601,7 @@ fn print_tree_schedule_by_forester( } } else if current_epoch_pda_entry.is_none() { println!( - "ERROR: Could not find EpochPda for active epoch {}. 
Cannot check forester assignments.", + "error: could not find EpochPda for active epoch {}; cannot check forester assignments", current_active_epoch ); } diff --git a/forester/src/health_check.rs b/forester/src/health_check.rs index db784f9ee1..8338861e0d 100644 --- a/forester/src/health_check.rs +++ b/forester/src/health_check.rs @@ -122,9 +122,9 @@ pub async fn run_health_check(args: &HealthArgs) -> Result } if !all_passed { - println!("\nHealth check FAILED"); + println!("\nHealth check failed"); } else { - println!("\nHealth check PASSED"); + println!("\nHealth check passed"); } } } @@ -245,7 +245,18 @@ async fn check_epoch_registration( ); } - Pubkey::new_from_array(bytes.try_into().unwrap()) + let bytes: [u8; 32] = match bytes.try_into() { + Ok(b) => b, + Err(_) => { + return HealthCheckResult::new( + "registration", + false, + "Derivation pubkey must be 32 bytes".to_string(), + start.elapsed().as_millis() as u64, + ); + } + }; + Pubkey::new_from_array(bytes) } else { match Pubkey::from_str(derivation) { Ok(pk) => pk, @@ -287,7 +298,17 @@ async fn check_epoch_registration( } }; - let protocol_config = get_protocol_config(&mut *rpc).await; + let protocol_config = match get_protocol_config(&mut *rpc).await { + Ok(cfg) => cfg, + Err(e) => { + return HealthCheckResult::new( + "registration", + false, + format!("Failed to fetch protocol config: {}", e), + start.elapsed().as_millis() as u64, + ); + } + }; let current_epoch = protocol_config.get_current_epoch(slot); let forester_epoch_pda_pubkey = get_forester_epoch_pda_from_authority(&derivation_pubkey, current_epoch).0; @@ -297,14 +318,13 @@ async fn check_epoch_registration( match registration_result { Ok(Some(pda)) => { - if pda.total_epoch_weight.is_some() { + if let Some(weight) = pda.total_epoch_weight { HealthCheckResult::new( "registration", true, format!( "Forester registered for epoch {} with weight {}", - current_epoch, - pda.total_epoch_weight.unwrap() + current_epoch, weight ), 
start.elapsed().as_millis() as u64, ) diff --git a/forester/src/lib.rs b/forester/src/lib.rs index de5254d0bb..e44e07f36c 100644 --- a/forester/src/lib.rs +++ b/forester/src/lib.rs @@ -10,7 +10,6 @@ pub mod health_check; pub mod helius_priority_fee_types; pub mod metrics; pub mod pagerduty; -pub mod polling; pub mod processor; pub mod pubsub_client; pub mod queue_helpers; @@ -19,7 +18,6 @@ pub mod slot_tracker; pub mod smart_transaction; pub mod telemetry; pub mod tree_data_sync; -pub mod tree_finder; pub mod utils; use std::{sync::Arc, time::Duration}; @@ -60,8 +58,7 @@ pub async fn run_queue_info( commitment_config: None, fetch_active_tree: false, }) - .await - .unwrap(); + .await?; let trees: Vec<_> = trees .iter() .filter(|t| t.tree_type == queue_type && !t.is_rolledover) @@ -182,19 +179,14 @@ pub async fn run_pipeline( builder = builder.send_tx_rate_limiter(limiter); } - let rpc_pool = builder.build().await?; - - let protocol_config = { - let mut rpc = rpc_pool.get_connection().await?; - get_protocol_config(&mut *rpc).await - }; - - let arc_pool = Arc::new(rpc_pool); + let arc_pool = Arc::new(builder.build().await?); let arc_pool_clone = Arc::clone(&arc_pool); - let slot = { - let rpc = arc_pool.get_connection().await?; - rpc.get_slot().await? 
+ let (protocol_config, slot) = { + let mut rpc = arc_pool.get_connection().await?; + let protocol_config = get_protocol_config(&mut *rpc).await?; + let slot = rpc.get_slot().await?; + (protocol_config, slot) }; let slot_tracker = SlotTracker::new( slot, @@ -202,12 +194,21 @@ pub async fn run_pipeline( ); let arc_slot_tracker = Arc::new(slot_tracker); let arc_slot_tracker_clone = arc_slot_tracker.clone(); - tokio::spawn(async move { - let mut rpc = arc_pool_clone - .get_connection() - .await - .expect("Failed to get RPC connection"); - SlotTracker::run(arc_slot_tracker_clone, &mut *rpc).await; + let slot_tracker_handle = tokio::spawn(async move { + loop { + match arc_pool_clone.get_connection().await { + Ok(mut rpc) => { + SlotTracker::run(arc_slot_tracker_clone.clone(), &mut *rpc).await; + // If SlotTracker::run returns, the connection likely failed + tracing::warn!("SlotTracker connection lost, reconnecting..."); + } + Err(e) => { + tracing::error!("Failed to get RPC connection for SlotTracker: {:?}", e); + } + } + // Wait before retrying + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + } }); let tx_cache = Arc::new(Mutex::new(ProcessedHashCache::new( @@ -218,7 +219,6 @@ pub async fn run_pipeline( config.transaction_config.ops_cache_ttl_seconds, ))); - // Start compressible subscriber if enabled and get tracker let compressible_tracker = if let Some(compressible_config) = &config.compressible_config { if let Some(shutdown_rx) = shutdown_compressible { let tracker = Arc::new(compressible::CompressibleAccountTracker::new()); @@ -282,7 +282,7 @@ pub async fn run_pipeline( }; debug!("Starting Forester pipeline"); - run_service( + let result = run_service( config, Arc::new(protocol_config), arc_pool, @@ -293,6 +293,11 @@ pub async fn run_pipeline( ops_cache, compressible_tracker, ) - .await?; - Ok(()) + .await; + + // Stop the SlotTracker task to prevent panic during shutdown + tracing::debug!("Stopping SlotTracker task"); + 
slot_tracker_handle.abort(); + + result } diff --git a/forester/src/main.rs b/forester/src/main.rs index c157ac54c5..48bbb49300 100644 --- a/forester/src/main.rs +++ b/forester/src/main.rs @@ -19,6 +19,38 @@ use tokio::{ }; use tracing::debug; +/// Spawns a task that handles graceful shutdown on Ctrl+C. +/// +/// First Ctrl+C triggers graceful shutdown by sending to `service_sender` +/// and calling the optional `additional_shutdown` closure. +/// Second Ctrl+C forces immediate exit. +fn spawn_shutdown_handler(service_sender: oneshot::Sender<()>, additional_shutdown: Option) +where + F: FnOnce() + Send + 'static, +{ + tokio::spawn(async move { + if let Err(e) = ctrl_c().await { + tracing::error!("Failed to listen for Ctrl+C: {}", e); + return; + } + tracing::info!("Received Ctrl+C, initiating graceful shutdown..."); + if service_sender.send(()).is_err() { + tracing::warn!("Shutdown signal to service already sent or receiver dropped"); + } + if let Some(shutdown_fn) = additional_shutdown { + shutdown_fn(); + } + + // Wait for second Ctrl+C to force exit + if let Err(e) = ctrl_c().await { + tracing::warn!("Failed to listen for second Ctrl+C (forcing exit): {}", e); + std::process::exit(1); + } + tracing::warn!("Received second Ctrl+C, forcing exit!"); + std::process::exit(1); + }); +} + #[tokio::main] #[allow(clippy::result_large_err)] async fn main() -> Result<(), ForesterError> { @@ -44,25 +76,19 @@ async fn main() -> Result<(), ForesterError> { tokio::sync::broadcast::channel(1); let (shutdown_sender_bootstrap, shutdown_receiver_bootstrap) = oneshot::channel(); - tokio::spawn(async move { - ctrl_c().await.expect("Failed to listen for Ctrl+C"); - shutdown_sender_service - .send(()) - .expect("Failed to send shutdown signal to service"); - let _ = shutdown_sender_compressible.send(()); - let _ = shutdown_sender_bootstrap.send(()); - }); + spawn_shutdown_handler( + shutdown_sender_service, + Some(move || { + let _ = shutdown_sender_compressible.send(()); + let _ = 
shutdown_sender_bootstrap.send(()); + }), + ); ( Some(shutdown_receiver_compressible), Some(shutdown_receiver_bootstrap), ) } else { - tokio::spawn(async move { - ctrl_c().await.expect("Failed to listen for Ctrl+C"); - shutdown_sender_service - .send(()) - .expect("Failed to send shutdown signal to service"); - }); + spawn_shutdown_handler::(shutdown_sender_service, None); (None, None) }; diff --git a/forester/src/metrics.rs b/forester/src/metrics.rs index 823fc97aa5..9b01c33cee 100644 --- a/forester/src/metrics.rs +++ b/forester/src/metrics.rs @@ -17,12 +17,18 @@ lazy_static! { prometheus::opts!("queue_length", "Length of the queue"), &["tree_type", "tree_pubkey"] ) - .expect("metric can be created"); + .unwrap_or_else(|e| { + error!("Failed to create metric QUEUE_LENGTH: {:?}", e); + std::process::exit(1); + }); pub static ref LAST_RUN_TIMESTAMP: IntGauge = IntGauge::new( "forester_last_run_timestamp", "Timestamp of the last Forester run" ) - .expect("metric can be created"); + .unwrap_or_else(|e| { + error!("Failed to create metric LAST_RUN_TIMESTAMP: {:?}", e); + std::process::exit(1); + }); pub static ref TRANSACTIONS_PROCESSED: IntCounterVec = IntCounterVec::new( prometheus::opts!( "forester_transactions_processed_total", @@ -30,7 +36,10 @@ lazy_static! { ), &["epoch"] ) - .expect("metric can be created"); + .unwrap_or_else(|e| { + error!("Failed to create metric TRANSACTIONS_PROCESSED: {:?}", e); + std::process::exit(1); + }); pub static ref TRANSACTION_TIMESTAMP: GaugeVec = GaugeVec::new( prometheus::opts!( "forester_transaction_timestamp", @@ -38,7 +47,10 @@ lazy_static! { ), &["epoch"] ) - .expect("metric can be created"); + .unwrap_or_else(|e| { + error!("Failed to create metric TRANSACTION_TIMESTAMP: {:?}", e); + std::process::exit(1); + }); pub static ref TRANSACTION_RATE: GaugeVec = GaugeVec::new( prometheus::opts!( "forester_transaction_rate", @@ -46,7 +58,10 @@ lazy_static! 
{ ), &["epoch"] ) - .expect("metric can be created"); + .unwrap_or_else(|e| { + error!("Failed to create metric TRANSACTION_RATE: {:?}", e); + std::process::exit(1); + }); pub static ref FORESTER_SOL_BALANCE: GaugeVec = GaugeVec::new( prometheus::opts!( "forester_sol_balance", @@ -54,12 +69,18 @@ lazy_static! { ), &["pubkey"] ) - .expect("metric can be created"); + .unwrap_or_else(|e| { + error!("Failed to create metric FORESTER_SOL_BALANCE: {:?}", e); + std::process::exit(1); + }); pub static ref REGISTERED_FORESTERS: GaugeVec = GaugeVec::new( prometheus::opts!("registered_foresters", "Foresters registered per epoch"), &["epoch", "authority"] ) - .expect("metric can be created"); + .unwrap_or_else(|e| { + error!("Failed to create metric REGISTERED_FORESTERS: {:?}", e); + std::process::exit(1); + }); static ref METRIC_UPDATES: Mutex> = Mutex::new(Vec::new()); } @@ -67,43 +88,49 @@ lazy_static! { static INIT: Once = Once::new(); pub fn register_metrics() { INIT.call_once(|| { - REGISTRY - .register(Box::new(QUEUE_LENGTH.clone())) - .expect("collector can be registered"); - REGISTRY - .register(Box::new(LAST_RUN_TIMESTAMP.clone())) - .expect("collector can be registered"); - REGISTRY - .register(Box::new(TRANSACTIONS_PROCESSED.clone())) - .expect("collector can be registered"); - REGISTRY - .register(Box::new(TRANSACTION_TIMESTAMP.clone())) - .expect("collector can be registered"); - REGISTRY - .register(Box::new(TRANSACTION_RATE.clone())) - .expect("collector can be registered"); - REGISTRY - .register(Box::new(FORESTER_SOL_BALANCE.clone())) - .expect("collector can be registered"); - REGISTRY - .register(Box::new(REGISTERED_FORESTERS.clone())) - .expect("collector can be registered"); + if let Err(e) = REGISTRY.register(Box::new(QUEUE_LENGTH.clone())) { + error!("Failed to register metric QUEUE_LENGTH: {:?}", e); + } + if let Err(e) = REGISTRY.register(Box::new(LAST_RUN_TIMESTAMP.clone())) { + error!("Failed to register metric LAST_RUN_TIMESTAMP: {:?}", e); + } + 
if let Err(e) = REGISTRY.register(Box::new(TRANSACTIONS_PROCESSED.clone())) { + error!("Failed to register metric TRANSACTIONS_PROCESSED: {:?}", e); + } + if let Err(e) = REGISTRY.register(Box::new(TRANSACTION_TIMESTAMP.clone())) { + error!("Failed to register metric TRANSACTION_TIMESTAMP: {:?}", e); + } + if let Err(e) = REGISTRY.register(Box::new(TRANSACTION_RATE.clone())) { + error!("Failed to register metric TRANSACTION_RATE: {:?}", e); + } + if let Err(e) = REGISTRY.register(Box::new(FORESTER_SOL_BALANCE.clone())) { + error!("Failed to register metric FORESTER_SOL_BALANCE: {:?}", e); + } + if let Err(e) = REGISTRY.register(Box::new(REGISTERED_FORESTERS.clone())) { + error!("Failed to register metric REGISTERED_FORESTERS: {:?}", e); + } }); } pub fn update_last_run_timestamp() { let now = SystemTime::now() .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs() as i64; + .map(|d| d.as_secs() as i64) + .unwrap_or_else(|e| { + error!("Failed to compute last run timestamp: {}", e); + 0 + }); LAST_RUN_TIMESTAMP.set(now); } pub fn update_transactions_processed(epoch: u64, count: usize, duration: std::time::Duration) { let now = SystemTime::now() .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs_f64(); + .map(|d| d.as_secs_f64()) + .unwrap_or_else(|e| { + error!("Failed to compute transaction timestamp: {}", e); + 0.0 + }); TRANSACTIONS_PROCESSED .with_label_values(&[&epoch.to_string()]) @@ -146,8 +173,10 @@ pub fn update_forester_sol_balance(pubkey: &str, balance: f64) { } pub fn update_registered_foresters(epoch: u64, authority: &str) { + let epoch_str = epoch.to_string(); + let authority_str = authority.to_string(); REGISTERED_FORESTERS - .with_label_values(&[&epoch.to_string(), authority]) + .with_label_values(&[epoch_str.as_str(), authority_str.as_str()]) .set(1.0); } diff --git a/forester/src/polling/mod.rs b/forester/src/polling/mod.rs deleted file mode 100644 index ac64bba8d8..0000000000 --- 
a/forester/src/polling/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod queue_poller; - -pub use queue_poller::{ - QueueInfoPoller, QueueUpdateMessage, RegisterTree, RegisteredTreeCount, UnregisterTree, -}; diff --git a/forester/src/polling/queue_poller.rs b/forester/src/polling/queue_poller.rs deleted file mode 100644 index 42bd374226..0000000000 --- a/forester/src/polling/queue_poller.rs +++ /dev/null @@ -1,284 +0,0 @@ -use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, - time::Duration, -}; - -use anyhow::Result; -use kameo::{ - actor::{ActorRef, WeakActorRef}, - error::ActorStopReason, - message::Message, - Actor, -}; -use light_client::indexer::{photon_indexer::PhotonIndexer, Indexer}; -use light_compressed_account::QueueType; -use solana_sdk::pubkey::Pubkey; -use tokio::sync::mpsc; -use tracing::{debug, error, info, trace, warn}; - -const POLLING_INTERVAL_SECS: u64 = 1; - -#[derive(Debug, Clone)] -pub struct QueueUpdateMessage { - pub tree: Pubkey, - pub queue: Pubkey, - pub queue_type: QueueType, - pub queue_size: u64, - pub slot: u64, -} - -pub struct QueueInfoPoller { - indexer: PhotonIndexer, - tree_notifiers: HashMap>, - polling_active: Arc, -} - -impl Actor for QueueInfoPoller { - type Args = Self; - type Error = anyhow::Error; - - async fn on_start(state: Self::Args, actor_ref: ActorRef) -> Result { - info!("QueueInfoPoller actor starting"); - - let polling_active = state.polling_active.clone(); - tokio::spawn(async move { - polling_loop(actor_ref, polling_active).await; - }); - - Ok(state) - } - - async fn on_stop( - &mut self, - _actor_ref: WeakActorRef, - _reason: ActorStopReason, - ) -> Result<()> { - info!("QueueInfoPoller actor stopping"); - // Use Release ordering to synchronize with Acquire loads in polling_loop - self.polling_active.store(false, Ordering::Release); - Ok(()) - } -} - -impl QueueInfoPoller { - pub fn new(indexer_url: String, api_key: Option) -> Self { - let indexer = 
PhotonIndexer::new(format!("{}/v1", indexer_url), api_key); - - Self { - indexer, - tree_notifiers: HashMap::new(), - polling_active: Arc::new(AtomicBool::new(true)), - } - } - - async fn poll_queue_info(&mut self) -> Result> { - match self.indexer.get_queue_info(None).await { - Ok(response) => { - let result = response.value; - let slot = result.slot; - - let queue_infos = result - .queues - .into_iter() - .map(|queue| QueueInfo { - tree: queue.tree, - queue: queue.queue, - queue_type: QueueType::from(queue.queue_type as u64), - queue_size: queue.queue_size, - slot, - }) - .collect(); - - Ok(queue_infos) - } - Err(e) => { - error!("Failed to call getQueueInfo: {:?}", e); - Err(anyhow::anyhow!("Failed to call getQueueInfo").context(e)) - } - } - } - - fn distribute_updates(&self, queue_infos: Vec) { - for info in queue_infos { - if let Some(tx) = self.tree_notifiers.get(&info.tree) { - let message = QueueUpdateMessage { - tree: info.tree, - queue: info.queue, - queue_type: info.queue_type, - queue_size: info.queue_size, - slot: info.slot, - }; - - match tx.try_send(message.clone()) { - Ok(()) => { - trace!( - "Routed update to tree {}: {} items (type: {:?})", - info.tree, - message.queue_size, - info.queue_type - ); - } - Err(mpsc::error::TrySendError::Full(_)) => { - debug!( - "Tree {} channel full, dropping update (tree processing slower than updates)", - info.tree - ); - } - Err(mpsc::error::TrySendError::Closed(_)) => { - trace!("Tree {} channel closed (task likely finished)", info.tree); - } - } - } - } - } -} - -#[derive(Debug, Clone)] -struct QueueInfo { - tree: Pubkey, - queue: Pubkey, - queue_type: QueueType, - queue_size: u64, - slot: u64, -} - -#[derive(Debug, Clone)] -pub struct RegisterTree { - pub tree_pubkey: Pubkey, -} - -impl Message for QueueInfoPoller { - type Reply = mpsc::Receiver; - - async fn handle( - &mut self, - msg: RegisterTree, - _ctx: &mut kameo::message::Context, - ) -> Self::Reply { - let (tx, rx) = mpsc::channel(100); - - // Check 
if there's already a sender registered for this tree - if let Some(old_sender) = self.tree_notifiers.insert(msg.tree_pubkey, tx) { - warn!( - "Double registration detected for tree {}. Replacing existing sender (previous receiver will be closed).", - msg.tree_pubkey - ); - // The old sender is dropped here, which will close the old receiver - drop(old_sender); - } else { - debug!("Registered tree {} for queue updates", msg.tree_pubkey); - } - - rx - } -} - -#[derive(Debug, Clone)] -pub struct UnregisterTree { - pub tree_pubkey: Pubkey, -} - -impl Message for QueueInfoPoller { - type Reply = (); - - async fn handle( - &mut self, - msg: UnregisterTree, - _ctx: &mut kameo::message::Context, - ) -> Self::Reply { - // Check if the tree was actually registered before removing - if let Some(sender) = self.tree_notifiers.remove(&msg.tree_pubkey) { - debug!("Unregistered tree {}", msg.tree_pubkey); - // Drop the sender to close the receiver - drop(sender); - } else { - warn!( - "Attempted to unregister non-existent tree {}. 
This may indicate a mismatch between receiver drops and explicit unregistration.", - msg.tree_pubkey - ); - } - } -} - -#[derive(Debug, Clone, Copy)] -pub struct RegisteredTreeCount; - -impl Message for QueueInfoPoller { - type Reply = usize; - - async fn handle( - &mut self, - _msg: RegisteredTreeCount, - _ctx: &mut kameo::message::Context, - ) -> Self::Reply { - self.tree_notifiers.len() - } -} - -#[derive(Debug, Clone, Copy)] -struct PollNow; - -impl Message for QueueInfoPoller { - type Reply = (); - - async fn handle( - &mut self, - _msg: PollNow, - _ctx: &mut kameo::message::Context, - ) -> Self::Reply { - if self.tree_notifiers.is_empty() { - debug!("No trees registered; skipping queue info poll"); - return; - } - - match self.poll_queue_info().await { - Ok(queue_infos) => { - self.distribute_updates(queue_infos); - } - Err(e) => { - error!("Failed to poll queue info: {:?}", e); - } - } - } -} - -async fn polling_loop(actor_ref: ActorRef, polling_active: Arc) { - info!("Starting queue info polling loop (1 second interval)"); - - let mut interval = tokio::time::interval(Duration::from_secs(POLLING_INTERVAL_SECS)); - interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); - - loop { - // Check if polling should continue - if !polling_active.load(Ordering::Acquire) { - info!("Polling loop shutting down (polling_active=false)"); - break; - } - - interval.tick().await; - - // Check again after the tick in case shutdown was signaled during sleep - if !polling_active.load(Ordering::Acquire) { - info!("Polling loop shutting down (polling_active=false)"); - break; - } - - match actor_ref.tell(PollNow).send().await { - Ok(_) => {} - Err(e) => { - if polling_active.load(Ordering::Acquire) { - error!("Failed to send poll message to actor: {:?}", e); - } else { - info!("Poll message send failed during shutdown: {:?}", e); - } - break; - } - } - } - - info!("Polling loop stopped"); -} diff --git a/forester/src/processor/tx_cache.rs 
b/forester/src/processor/tx_cache.rs index a2b8d5d120..6cbcc9b90f 100644 --- a/forester/src/processor/tx_cache.rs +++ b/forester/src/processor/tx_cache.rs @@ -3,6 +3,8 @@ use std::{collections::HashMap, time::Duration}; use tokio::time::Instant; use tracing::{trace, warn}; +const CLEANUP_INTERVAL: Duration = Duration::from_secs(5); + #[derive(Debug, Clone)] struct CacheEntry { timestamp: Instant, @@ -13,6 +15,7 @@ struct CacheEntry { pub struct ProcessedHashCache { entries: HashMap, ttl: Duration, + last_cleanup: Instant, } impl ProcessedHashCache { @@ -20,10 +23,12 @@ impl ProcessedHashCache { Self { entries: HashMap::new(), ttl: Duration::from_secs(ttl_seconds), + last_cleanup: Instant::now(), } } pub fn add(&mut self, hash: &str) { + self.maybe_cleanup(); self.entries.insert( hash.to_string(), CacheEntry { @@ -34,6 +39,7 @@ impl ProcessedHashCache { } pub fn add_with_timeout(&mut self, hash: &str, timeout: Duration) { + self.maybe_cleanup(); self.entries.insert( hash.to_string(), CacheEntry { @@ -50,7 +56,7 @@ impl ProcessedHashCache { } pub fn contains(&mut self, hash: &str) -> bool { - self.cleanup(); + self.maybe_cleanup(); if let Some(entry) = self.entries.get(hash) { let age = Instant::now().duration_since(entry.timestamp); if age > Duration::from_secs(60) && age < entry.timeout { @@ -73,8 +79,15 @@ impl ProcessedHashCache { .map(|entry| Instant::now().duration_since(entry.timestamp)) } - pub fn cleanup(&mut self) { + #[inline] + fn maybe_cleanup(&mut self) { let now = Instant::now(); + if now.duration_since(self.last_cleanup) >= CLEANUP_INTERVAL { + self.cleanup_internal(now); + } + } + + fn cleanup_internal(&mut self, now: Instant) { self.entries.retain(|hash, entry| { let age = now.duration_since(entry.timestamp); let should_keep = age < entry.timeout; @@ -86,6 +99,11 @@ impl ProcessedHashCache { } should_keep }); + self.last_cleanup = now; + } + + pub fn cleanup(&mut self) { + self.cleanup_internal(Instant::now()); } pub fn cleanup_by_key(&mut self, 
key: &str) { diff --git a/forester/src/processor/v1/helpers.rs b/forester/src/processor/v1/helpers.rs index ddd22362c4..1b4e4609f6 100644 --- a/forester/src/processor/v1/helpers.rs +++ b/forester/src/processor/v1/helpers.rs @@ -348,11 +348,15 @@ pub fn calculate_compute_unit_price(target_lamports: u64, compute_units: u64) -> #[allow(dead_code)] pub fn get_capped_priority_fee(cap_config: CapConfig) -> u64 { if cap_config.max_fee_lamports < cap_config.min_fee_lamports { - panic!("Max fee is less than min fee"); + warn!( + "Invalid priority fee cap config: max_fee_lamports ({}) < min_fee_lamports ({}); clamping max to min", + cap_config.max_fee_lamports, cap_config.min_fee_lamports + ); } + let max_fee_lamports = cap_config.max_fee_lamports.max(cap_config.min_fee_lamports); let priority_fee_max = - calculate_compute_unit_price(cap_config.max_fee_lamports, cap_config.compute_unit_limit); + calculate_compute_unit_price(max_fee_lamports, cap_config.compute_unit_limit); let priority_fee_min = calculate_compute_unit_price(cap_config.min_fee_lamports, cap_config.compute_unit_limit); let capped_fee = std::cmp::min(cap_config.rec_fee_microlamports_per_cu, priority_fee_max); diff --git a/forester/src/processor/v2/address.rs b/forester/src/processor/v2/address.rs deleted file mode 100644 index d2d7a711bb..0000000000 --- a/forester/src/processor/v2/address.rs +++ /dev/null @@ -1,60 +0,0 @@ -use anyhow::Error; -use borsh::BorshSerialize; -use forester_utils::instructions::address_batch_update::{ - get_address_update_instruction_stream, AddressUpdateConfig, -}; -use futures::stream::{Stream, StreamExt}; -use light_batched_merkle_tree::merkle_tree::InstructionDataAddressAppendInputs; -use light_client::rpc::Rpc; -use light_registry::account_compression_cpi::sdk::create_batch_update_address_tree_instruction; -use solana_program::instruction::Instruction; -use solana_sdk::signer::Signer; -use tracing::instrument; - -use super::common::{process_stream, BatchContext, 
ParsedMerkleTreeData}; -use crate::Result; - -async fn create_stream_future( - ctx: &BatchContext, - merkle_tree_data: ParsedMerkleTreeData, -) -> Result<( - impl Stream>> + Send, - u16, -)> -where - R: Rpc, -{ - let config = AddressUpdateConfig { - rpc_pool: ctx.rpc_pool.clone(), - merkle_tree_pubkey: ctx.merkle_tree, - prover_url: ctx.prover_config.address_append_url.clone(), - prover_api_key: ctx.prover_config.api_key.clone(), - polling_interval: ctx.prover_config.polling_interval, - max_wait_time: ctx.prover_config.max_wait_time, - }; - let (stream, size) = get_address_update_instruction_stream(config, merkle_tree_data) - .await - .map_err(Error::from)?; - let stream = stream.map(|item| item.map_err(Error::from)); - Ok((stream, size)) -} - -#[instrument(level = "debug", skip(context, merkle_tree_data), fields(tree = %context.merkle_tree))] -pub(crate) async fn process_batch( - context: &BatchContext, - merkle_tree_data: ParsedMerkleTreeData, -) -> Result { - let instruction_builder = |data: &InstructionDataAddressAppendInputs| -> Instruction { - let serialized_data = data.try_to_vec().unwrap(); - create_batch_update_address_tree_instruction( - context.authority.pubkey(), - context.derivation, - context.merkle_tree, - context.epoch, - serialized_data, - ) - }; - - let stream_future = create_stream_future(context, merkle_tree_data); - process_stream(context, stream_future, instruction_builder).await -} diff --git a/forester/src/processor/v2/batch_job_builder.rs b/forester/src/processor/v2/batch_job_builder.rs new file mode 100644 index 0000000000..90187a2d25 --- /dev/null +++ b/forester/src/processor/v2/batch_job_builder.rs @@ -0,0 +1,20 @@ +use crate::processor::v2::proof_worker::ProofInput; + +pub trait BatchJobBuilder { + /// Build proof job for a batch. 
Returns: + /// - `Ok(Some((input, root)))` - batch processed, proof job created + /// - `Ok(None)` - batch should be skipped (e.g., overlap with already-processed data) + /// - `Err(...)` - fatal error, stop processing + fn build_proof_job( + &mut self, + batch_idx: usize, + zkp_batch_size: u64, + epoch: u64, + tree: &str, + ) -> crate::Result>; + + fn available_batches(&self, zkp_batch_size: u64) -> usize { + let _ = zkp_batch_size; + usize::MAX + } +} diff --git a/forester/src/processor/v2/common.rs b/forester/src/processor/v2/common.rs index 84f1da8c55..9f0616c9d4 100644 --- a/forester/src/processor/v2/common.rs +++ b/forester/src/processor/v2/common.rs @@ -1,5 +1,4 @@ use std::{ - future::Future, sync::{ atomic::{AtomicU64, Ordering}, Arc, @@ -7,29 +6,63 @@ use std::{ time::Duration, }; -use borsh::BorshSerialize; -use forester_utils::{ - forester_epoch::EpochPhases, rpc_pool::SolanaRpcPool, utils::wait_for_indexer, -}; +use forester_utils::{forester_epoch::EpochPhases, rpc_pool::SolanaRpcPool}; pub use forester_utils::{ParsedMerkleTreeData, ParsedQueueData}; -use futures::{pin_mut, stream::StreamExt, Stream}; -use light_batched_merkle_tree::{ - batch::BatchState, merkle_tree::BatchedMerkleTreeAccount, queue::BatchedQueueAccount, -}; use light_client::rpc::Rpc; -use light_compressed_account::TreeType; use light_registry::protocol_config::state::EpochState; -use solana_sdk::{instruction::Instruction, pubkey::Pubkey, signature::Keypair, signer::Signer}; +use solana_sdk::{ + address_lookup_table::AddressLookupTableAccount, instruction::Instruction, pubkey::Pubkey, + signature::Keypair, signer::Signer, +}; use tokio::sync::Mutex; -use tracing::{debug, error, info, trace}; +use tracing::{debug, error, info, warn}; -use super::address; +use super::{errors::V2Error, proof_worker::ProofJob}; use crate::{ errors::ForesterError, processor::tx_cache::ProcessedHashCache, slot_tracker::SlotTracker, Result, }; -const SLOTS_STOP_THRESHOLD: u64 = 1; +const 
SLOTS_STOP_THRESHOLD: u64 = 3; + +#[derive(Debug)] +pub struct WorkerPool { + pub job_tx: async_channel::Sender, +} + +pub fn clamp_to_u16(value: u64, name: &str) -> u16 { + match value.try_into() { + Ok(v) => v, + Err(_) => { + tracing::warn!( + "{} {} exceeds u16::MAX, clamping to {}", + name, + value, + u16::MAX + ); + u16::MAX + } + } +} + +#[inline] +pub fn batch_range(zkp_batch_size: u64, total_len: usize, start: usize) -> std::ops::Range { + let end = (start + zkp_batch_size as usize).min(total_len); + start..end +} + +pub fn get_leaves_hashchain( + leaves_hash_chains: &[[u8; 32]], + batch_idx: usize, +) -> crate::Result<[u8; 32]> { + leaves_hash_chains.get(batch_idx).copied().ok_or_else(|| { + anyhow::anyhow!( + "Missing leaves_hash_chain for batch {} (available: {})", + batch_idx, + leaves_hash_chains.len() + ) + }) +} #[derive(Debug, Clone)] pub struct ProverConfig { @@ -41,16 +74,6 @@ pub struct ProverConfig { pub max_wait_time: Duration, } -#[derive(Debug)] -#[allow(clippy::enum_variant_names)] -pub enum BatchReadyState { - NotReady, - AddressReady { - merkle_tree_data: ParsedMerkleTreeData, - }, - StateReady, -} - #[derive(Debug)] pub struct BatchContext { pub rpc_pool: Arc>, @@ -59,7 +82,7 @@ pub struct BatchContext { pub epoch: u64, pub merkle_tree: Pubkey, pub output_queue: Pubkey, - pub prover_config: ProverConfig, + pub prover_config: Arc, pub ops_cache: Arc>, pub epoch_phases: EpochPhases, pub slot_tracker: Arc, @@ -67,6 +90,11 @@ pub struct BatchContext { pub output_queue_hint: Option, pub num_proof_workers: usize, pub forester_eligibility_end_slot: Arc, + pub address_lookup_tables: Arc>, + /// Maximum attempts to confirm a transaction before timing out. + pub confirmation_max_attempts: u32, + /// Interval between confirmation polling attempts. 
+ pub confirmation_poll_interval: Duration, } impl Clone for BatchContext { @@ -86,115 +114,13 @@ impl Clone for BatchContext { output_queue_hint: self.output_queue_hint, num_proof_workers: self.num_proof_workers, forester_eligibility_end_slot: self.forester_eligibility_end_slot.clone(), + address_lookup_tables: self.address_lookup_tables.clone(), + confirmation_max_attempts: self.confirmation_max_attempts, + confirmation_poll_interval: self.confirmation_poll_interval, } } } -#[derive(Debug)] -pub struct BatchProcessor { - context: BatchContext, - tree_type: TreeType, -} - -pub(crate) async fn process_stream( - context: &BatchContext, - stream_creator_future: FutC, - instruction_builder: impl Fn(&D) -> Instruction, -) -> Result -where - R: Rpc, - S: Stream>> + Send, - D: BorshSerialize, - FutC: Future> + Send, -{ - trace!("Executing batched stream processor (hybrid)"); - - let (batch_stream, zkp_batch_size) = stream_creator_future.await?; - - if zkp_batch_size == 0 { - trace!("ZKP batch size is 0, no work to do."); - return Ok(0); - } - - pin_mut!(batch_stream); - let mut total_instructions_processed = 0; - - while let Some(batch_result) = batch_stream.next().await { - let instruction_batch = batch_result?; - - if instruction_batch.is_empty() { - continue; - } - - let current_slot = context.slot_tracker.estimated_current_slot(); - let forester_end = context - .forester_eligibility_end_slot - .load(Ordering::Acquire); - let eligibility_end_slot = if forester_end > 0 { - forester_end - } else { - context.epoch_phases.active.end - }; - let slots_remaining = eligibility_end_slot.saturating_sub(current_slot); - - if slots_remaining < SLOTS_STOP_THRESHOLD { - info!( - "Only {} slots remaining until eligibility ends (threshold {}), stopping batch processing", - slots_remaining, SLOTS_STOP_THRESHOLD - ); - if !instruction_batch.is_empty() { - let instructions: Vec = - instruction_batch.iter().map(&instruction_builder).collect(); - let _ = send_transaction_batch(context, 
instructions).await; - } - break; - } - - let instructions: Vec = - instruction_batch.iter().map(&instruction_builder).collect(); - - match send_transaction_batch(context, instructions.clone()).await { - Ok(sig) => { - total_instructions_processed += instruction_batch.len(); - debug!( - "Successfully processed batch with {} instructions, signature: {}", - instruction_batch.len(), - sig - ); - - { - let rpc = context.rpc_pool.get_connection().await?; - wait_for_indexer(&*rpc) - .await - .map_err(|e| anyhow::anyhow!("Error waiting for indexer: {:?}", e))?; - } - } - Err(e) => { - if let Some(ForesterError::NotInActivePhase) = e.downcast_ref::() { - info!("Active phase ended while processing batches, stopping gracefully"); - break; - } else { - error!( - "Failed to process batch with {} instructions for tree {}: {:?}", - instructions.len(), - context.merkle_tree, - e - ); - return Err(e); - } - } - } - } - - if total_instructions_processed == 0 { - trace!("No instructions were processed from the stream."); - return Ok(0); - } - - let total_items_processed = total_instructions_processed * zkp_batch_size as usize; - Ok(total_items_processed) -} - pub(crate) async fn send_transaction_batch( context: &BatchContext, instructions: Vec, @@ -204,7 +130,7 @@ pub(crate) async fn send_transaction_batch( if current_phase_state != EpochState::Active { debug!( - "!! 
Skipping transaction send: not in active phase (current phase: {:?}, slot: {})", + "Skipping transaction send: not in active phase (current phase: {:?}, slot: {})", current_phase_state, current_slot ); return Err(ForesterError::NotInActivePhase.into()); @@ -233,243 +159,82 @@ pub(crate) async fn send_transaction_batch( context.merkle_tree ); let mut rpc = context.rpc_pool.get_connection().await?; - let signature = rpc - .create_and_send_transaction( + + let signature = if !context.address_lookup_tables.is_empty() { + debug!( + "Using versioned transaction with {} lookup tables", + context.address_lookup_tables.len() + ); + rpc.create_and_send_versioned_transaction( &instructions, &context.authority.pubkey(), &[context.authority.as_ref()], + &context.address_lookup_tables, ) - .await?; + .await? + } else { + rpc.create_and_send_transaction( + &instructions, + &context.authority.pubkey(), + &[context.authority.as_ref()], + ) + .await? + }; - // Ensure transaction is confirmed before returning debug!("Waiting for transaction confirmation: {}", signature); - let confirmed = rpc.confirm_transaction(signature).await?; - if !confirmed { - return Err(anyhow::anyhow!( - "Transaction {} failed to confirm for tree {}", - signature, - context.merkle_tree - )); - } - - info!( - "Transaction confirmed successfully: {} for tree: {}", - signature, context.merkle_tree - ); - Ok(signature.to_string()) -} + let max_attempts = context.confirmation_max_attempts; + let poll_interval = context.confirmation_poll_interval; -impl BatchProcessor { - pub fn new(context: BatchContext, tree_type: TreeType) -> Self { - Self { context, tree_type } - } + for attempt in 0..max_attempts { + let statuses = rpc.get_signature_statuses(&[signature]).await?; - pub async fn process(&self) -> Result { - trace!( - "Starting batch processing for tree type: {:?}", - self.tree_type - ); - let state = self.verify_batch_ready().await; - - match state { - BatchReadyState::AddressReady { merkle_tree_data } => { - 
trace!( - "Processing address append for tree: {}", - self.context.merkle_tree + if let Some(Some(status)) = statuses.first() { + if let Some(err) = &status.err { + error!( + "transaction {} failed for tree {}: {:?}", + signature, context.merkle_tree, err ); - - let batch_hash = format!( - "address_batch_{}_{}", - self.context.merkle_tree, self.context.epoch - ); - { - let mut cache = self.context.ops_cache.lock().await; - if cache.contains(&batch_hash) { - debug!("Skipping already processed address batch: {}", batch_hash); - return Ok(0); - } - cache.add(&batch_hash); - } - - let result = address::process_batch(&self.context, merkle_tree_data).await; - - if let Err(ref e) = result { - error!( - "Address append failed for tree {}: {:?}", - self.context.merkle_tree, e - ); - } - - let mut cache = self.context.ops_cache.lock().await; - cache.cleanup_by_key(&batch_hash); - trace!("Cache cleaned up for batch: {}", batch_hash); - - result - } - BatchReadyState::StateReady => { - trace!( - "State processing handled by supervisor pipeline; skipping legacy processor for tree {}", - self.context.merkle_tree - ); - Ok(0) + return Err(V2Error::from_transaction_error(context.merkle_tree, err).into()); } - BatchReadyState::NotReady => { - trace!( - "Batch not ready for processing, tree: {}", - self.context.merkle_tree - ); - Ok(0) - } - } - } - - async fn verify_batch_ready(&self) -> BatchReadyState { - let rpc = match self.context.rpc_pool.get_connection().await { - Ok(rpc) => rpc, - Err(_) => return BatchReadyState::NotReady, - }; - - let merkle_tree_account = rpc - .get_account(self.context.merkle_tree) - .await - .ok() - .flatten(); - let output_queue_account = if self.tree_type == TreeType::StateV2 { - rpc.get_account(self.context.output_queue) - .await - .ok() - .flatten() - } else { - None - }; - let (merkle_tree_data, input_ready) = if let Some(mut account) = merkle_tree_account { - match self.parse_merkle_tree_account(&mut account) { - Ok((data, ready)) => (Some(data), 
ready), - Err(_) => (None, false), + // Transaction succeeded - check confirmation status + // confirmations == None means finalized, Some(n) means n confirmations + let is_confirmed = status.confirmations.is_none() || status.confirmations >= Some(1); + if is_confirmed { + info!( + "Transaction confirmed successfully: {} for tree: {} (slot: {}, confirmations: {:?})", + signature, context.merkle_tree, status.slot, status.confirmations + ); + return Ok(signature.to_string()); } - } else { - (None, false) - }; - if self.tree_type == TreeType::AddressV2 { - return if input_ready { - if let Some(merkle_tree_data) = merkle_tree_data { - BatchReadyState::AddressReady { merkle_tree_data } - } else { - BatchReadyState::NotReady - } - } else { - BatchReadyState::NotReady - }; - } - - // StateV2: check output queue readiness - let output_ready = if let Some(mut account) = output_queue_account { - self.parse_output_queue_account(&mut account) - .map(|(_, ready)| ready) - .unwrap_or(false) + debug!( + "Transaction {} pending confirmation (attempt {}/{}, confirmations: {:?})", + signature, + attempt + 1, + max_attempts, + status.confirmations + ); } else { - false - }; - - if !input_ready && !output_ready { - return BatchReadyState::NotReady; - } - - // State batch is ready; StateSupervisor will handle the actual processing - BatchReadyState::StateReady - } - - fn parse_merkle_tree_account( - &self, - account: &mut solana_sdk::account::Account, - ) -> Result<(ParsedMerkleTreeData, bool)> { - let merkle_tree = match self.tree_type { - TreeType::AddressV2 => BatchedMerkleTreeAccount::address_from_bytes( - account.data.as_mut_slice(), - &self.context.merkle_tree.into(), - ), - TreeType::StateV2 => BatchedMerkleTreeAccount::state_from_bytes( - account.data.as_mut_slice(), - &self.context.merkle_tree.into(), - ), - _ => return Err(ForesterError::InvalidTreeType(self.tree_type).into()), - }?; - - let batch_index = merkle_tree.queue_batches.pending_batch_index; - let batch = 
merkle_tree - .queue_batches - .batches - .get(batch_index as usize) - .ok_or_else(|| anyhow::anyhow!("Batch not found"))?; - - let num_inserted_zkps = batch.get_num_inserted_zkps(); - let current_zkp_batch_index = batch.get_current_zkp_batch_index(); - - let mut leaves_hash_chains = Vec::new(); - for i in num_inserted_zkps..current_zkp_batch_index { - leaves_hash_chains - .push(merkle_tree.hash_chain_stores[batch_index as usize][i as usize]); + debug!( + "Transaction {} not yet visible (attempt {}/{})", + signature, + attempt + 1, + max_attempts + ); } - let onchain_root = *merkle_tree - .root_history - .last() - .ok_or_else(|| anyhow::anyhow!("Merkle tree root history is empty"))?; - - let parsed_data = ParsedMerkleTreeData { - next_index: merkle_tree.next_index, - current_root: onchain_root, - root_history: merkle_tree.root_history.to_vec(), - zkp_batch_size: batch.zkp_batch_size as u16, - pending_batch_index: batch_index as u32, - num_inserted_zkps, - current_zkp_batch_index, - batch_start_index: batch.start_index, - leaves_hash_chains, - }; - - let is_ready = batch.get_state() != BatchState::Inserted - && batch.get_current_zkp_batch_index() > batch.get_num_inserted_zkps(); - - Ok((parsed_data, is_ready)) + tokio::time::sleep(poll_interval).await; } - fn parse_output_queue_account( - &self, - account: &mut solana_sdk::account::Account, - ) -> Result<(ParsedQueueData, bool)> { - let output_queue = BatchedQueueAccount::output_from_bytes(account.data.as_mut_slice())?; - - let batch_index = output_queue.batch_metadata.pending_batch_index; - let batch = output_queue - .batch_metadata - .batches - .get(batch_index as usize) - .ok_or_else(|| anyhow::anyhow!("Batch not found"))?; - - let num_inserted_zkps = batch.get_num_inserted_zkps(); - let current_zkp_batch_index = batch.get_current_zkp_batch_index(); - - let mut leaves_hash_chains = Vec::new(); - for i in num_inserted_zkps..current_zkp_batch_index { - leaves_hash_chains - 
.push(output_queue.hash_chain_stores[batch_index as usize][i as usize]); - } - - let parsed_data = ParsedQueueData { - zkp_batch_size: output_queue.batch_metadata.zkp_batch_size as u16, - pending_batch_index: batch_index as u32, - num_inserted_zkps, - current_zkp_batch_index, - leaves_hash_chains, - }; - - let is_ready = batch.get_state() != BatchState::Inserted - && batch.get_current_zkp_batch_index() > batch.get_num_inserted_zkps(); - - Ok((parsed_data, is_ready)) + warn!( + "Transaction {} timed out waiting for confirmation for tree {}", + signature, context.merkle_tree + ); + Err(V2Error::TransactionTimeout { + signature: signature.to_string(), + context: format!("waiting for confirmation for tree {}", context.merkle_tree), } + .into()) } diff --git a/forester/src/processor/v2/errors.rs b/forester/src/processor/v2/errors.rs new file mode 100644 index 0000000000..8512ffc0ff --- /dev/null +++ b/forester/src/processor/v2/errors.rs @@ -0,0 +1,141 @@ +use std::fmt; + +use solana_sdk::{instruction::InstructionError, pubkey::Pubkey, transaction::TransactionError}; +use thiserror::Error; + +/// Matches `light_verifier::VerifierError::ProofVerificationFailed`. 
+const PROOF_VERIFICATION_FAILED_ERROR_CODE: u32 = 13006; + +fn fmt_root_prefix(root: &[u8; 32]) -> String { + format!( + "{:02x}{:02x}{:02x}{:02x}", + root[0], root[1], root[2], root[3] + ) +} + +#[derive(Debug, Error)] +pub enum V2Error { + #[error("{}", .0)] + RootMismatch(#[from] RootMismatchError), + + #[error("{}", .0)] + IndexerLag(#[from] IndexerLagError), + + #[error("stale tree for tree_id {tree_id}: {details}")] + StaleTree { tree_id: String, details: String }, + + #[error("proof patch failed for tree_id {tree_id}: {details}")] + ProofPatchFailed { tree_id: String, details: String }, + + #[error("hashchain mismatch for tree_id {tree_id}: {details}")] + HashchainMismatch { tree_id: String, details: String }, + + #[error("circuit constraint failure for tree {tree}: code={code:?} {message}")] + CircuitConstraint { + tree: Pubkey, + code: Option, + message: String, + }, + + #[error("transaction failed for tree {tree}: {message}")] + TransactionFailed { tree: Pubkey, message: String }, + + #[error("transaction {signature} timed out: {context}")] + TransactionTimeout { signature: String, context: String }, +} + +impl V2Error { + pub fn from_transaction_error(tree: Pubkey, err: &TransactionError) -> Self { + let message = format!("{:?}", err); + let custom_code = match err { + TransactionError::InstructionError(_, InstructionError::Custom(code)) => Some(*code), + _ => None, + }; + + if matches!(custom_code, Some(PROOF_VERIFICATION_FAILED_ERROR_CODE)) { + return V2Error::CircuitConstraint { + tree, + code: custom_code, + message, + }; + } + + V2Error::TransactionFailed { tree, message } + } + + pub fn is_constraint(&self) -> bool { + matches!(self, V2Error::CircuitConstraint { .. }) + } + + pub fn is_hashchain_mismatch(&self) -> bool { + matches!(self, V2Error::HashchainMismatch { .. 
}) + } + + pub fn root_mismatch( + tree: Pubkey, + expected: [u8; 32], + indexer: [u8; 32], + onchain: [u8; 32], + ) -> Self { + RootMismatchError { + tree, + expected, + indexer, + onchain, + } + .into() + } + + pub fn indexer_lag(tree: Pubkey, expected: [u8; 32], indexer: [u8; 32]) -> Self { + IndexerLagError { + tree, + expected, + indexer, + } + .into() + } +} + +#[derive(Debug)] +pub struct RootMismatchError { + pub tree: Pubkey, + pub expected: [u8; 32], + pub indexer: [u8; 32], + pub onchain: [u8; 32], +} + +impl fmt::Display for RootMismatchError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "root mismatch for tree {}: expected {}, indexer {}, onchain {}", + self.tree, + fmt_root_prefix(&self.expected), + fmt_root_prefix(&self.indexer), + fmt_root_prefix(&self.onchain) + ) + } +} + +impl std::error::Error for RootMismatchError {} + +#[derive(Debug)] +pub struct IndexerLagError { + pub tree: Pubkey, + pub expected: [u8; 32], + pub indexer: [u8; 32], +} + +impl fmt::Display for IndexerLagError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "indexer lag for tree {}: expected {}, indexer {}", + self.tree, + fmt_root_prefix(&self.expected), + fmt_root_prefix(&self.indexer) + ) + } +} + +impl std::error::Error for IndexerLagError {} diff --git a/forester/src/processor/v2/helpers.rs b/forester/src/processor/v2/helpers.rs new file mode 100644 index 0000000000..d58723c78e --- /dev/null +++ b/forester/src/processor/v2/helpers.rs @@ -0,0 +1,717 @@ +use std::{ + collections::{HashMap, HashSet}, + sync::{Arc, Condvar, Mutex, MutexGuard}, +}; + +use anyhow::anyhow; +use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; +use light_client::{ + indexer::{AddressQueueData, Indexer, QueueElementsV2Options, StateQueueData}, + rpc::Rpc, +}; + +use crate::processor::v2::{common::clamp_to_u16, BatchContext}; + +pub(crate) fn lock_recover<'a, T>(mutex: &'a Mutex, name: &'static str) -> 
MutexGuard<'a, T> { + match mutex.lock() { + Ok(guard) => guard, + Err(poisoned) => { + tracing::warn!("Poisoned mutex (recovering): {}", name); + poisoned.into_inner() + } + } +} + +pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { + let rpc = context.rpc_pool.get_connection().await?; + let mut account = rpc + .get_account(context.merkle_tree) + .await? + .ok_or_else(|| anyhow!("Merkle tree account not found"))?; + + let tree = BatchedMerkleTreeAccount::state_from_bytes( + account.data.as_mut_slice(), + &context.merkle_tree.into(), + )?; + + let batch_index = tree.queue_batches.pending_batch_index; + let batch = tree + .queue_batches + .batches + .get(batch_index as usize) + .ok_or_else(|| anyhow!("Batch not found"))?; + + Ok(batch.zkp_batch_size) +} + +pub async fn fetch_onchain_state_root( + context: &BatchContext, +) -> crate::Result<[u8; 32]> { + let rpc = context.rpc_pool.get_connection().await?; + let mut account = rpc + .get_account(context.merkle_tree) + .await? + .ok_or_else(|| anyhow!("Merkle tree account not found"))?; + + let tree = BatchedMerkleTreeAccount::state_from_bytes( + account.data.as_mut_slice(), + &context.merkle_tree.into(), + )?; + + // Get the current root (last entry in root_history) + let root = tree + .root_history + .last() + .copied() + .ok_or_else(|| anyhow!("Root history is empty"))?; + + Ok(root) +} + +pub async fn fetch_address_zkp_batch_size(context: &BatchContext) -> crate::Result { + let rpc = context.rpc_pool.get_connection().await?; + let mut account = rpc + .get_account(context.merkle_tree) + .await? 
+ .ok_or_else(|| anyhow!("Merkle tree account not found"))?; + + let tree = BatchedMerkleTreeAccount::address_from_bytes( + account.data.as_mut_slice(), + &context.merkle_tree.into(), + ) + .map_err(|e| anyhow!("Failed to deserialize address tree: {}", e))?; + + let batch_index = tree.queue_batches.pending_batch_index; + let batch = tree + .queue_batches + .batches + .get(batch_index as usize) + .ok_or_else(|| anyhow!("Batch not found"))?; + + Ok(batch.zkp_batch_size) +} + +pub async fn fetch_onchain_address_root( + context: &BatchContext, +) -> crate::Result<[u8; 32]> { + let rpc = context.rpc_pool.get_connection().await?; + let mut account = rpc + .get_account(context.merkle_tree) + .await? + .ok_or_else(|| anyhow!("Merkle tree account not found"))?; + + let tree = BatchedMerkleTreeAccount::address_from_bytes( + account.data.as_mut_slice(), + &context.merkle_tree.into(), + ) + .map_err(|e| anyhow!("Failed to deserialize address tree: {}", e))?; + + let root = tree + .root_history + .last() + .copied() + .ok_or_else(|| anyhow!("Root history is empty"))?; + + Ok(root) +} + +const INDEXER_FETCH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); +const ADDRESS_INDEXER_FETCH_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120); +const PAGE_SIZE_BATCHES: u64 = 5; +const ADDRESS_PAGE_SIZE_BATCHES: u64 = 5; + +pub async fn fetch_paginated_batches( + context: &BatchContext, + total_elements: u64, + zkp_batch_size: u64, +) -> crate::Result> { + if zkp_batch_size == 0 { + return Err(anyhow::anyhow!("zkp_batch_size cannot be zero")); + } + if total_elements == 0 { + return Ok(None); + } + + let page_size_elements = PAGE_SIZE_BATCHES * zkp_batch_size; + if total_elements <= page_size_elements { + tracing::info!( + "fetch_paginated_batches: single page fetch with start_index=None, total_elements={}, page_size={}", + total_elements, page_size_elements + ); + return fetch_batches(context, None, None, total_elements, zkp_batch_size).await; + } + + 
let num_pages = total_elements.div_ceil(page_size_elements) as usize; + tracing::debug!( + "Parallel fetch: {} elements ({} batches) in {} pages of {} batches each", + total_elements, + total_elements / zkp_batch_size, + num_pages, + PAGE_SIZE_BATCHES + ); + + // Fetch first page with start_index=None to discover the actual first_queue_index + // (queue may have been pruned, so indices don't start at 0) + let first_page = fetch_batches(context, None, None, page_size_elements, zkp_batch_size).await?; + + let Some(first_page_data) = first_page else { + return Ok(None); + }; + + // Get the actual starting indices from the first page response + // IMPORTANT: Only use first_queue_index if the queue actually has elements. + // When queue is empty, photon returns default first_queue_index=0, which would + // cause subsequent pages to request start_index=2500 even though the actual + // queue might start at 149500 (if elements arrive between requests). + let output_first_index: Option = first_page_data + .output_queue + .as_ref() + .filter(|oq| !oq.leaf_indices.is_empty()) + .map(|oq| oq.first_queue_index); + let input_first_index: Option = first_page_data + .input_queue + .as_ref() + .filter(|iq| !iq.leaf_indices.is_empty()) + .map(|iq| iq.first_queue_index); + + tracing::debug!( + "First page fetched: output_first_index={:?}, input_first_index={:?}", + output_first_index, + input_first_index + ); + + // If only one page needed, return the first page result + if num_pages == 1 { + return Ok(Some(first_page_data)); + } + + // Fetch remaining pages in parallel with offsets relative to first_queue_index + // Only request queues for which we have valid first_queue_index from the first page + let mut fetch_futures = Vec::with_capacity(num_pages - 1); + let mut offset = page_size_elements; + + for _page_idx in 1..num_pages { + let page_size = (total_elements - offset).min(page_size_elements); + // Only use Some(index) for queues we actually got data for in the first page + // 
If first page had no data for a queue, we don't know its first_queue_index + let output_start = output_first_index.map(|idx| idx + offset); + let input_start = input_first_index.map(|idx| idx + offset); + + let ctx = context.clone(); + + fetch_futures.push(async move { + fetch_batches(&ctx, output_start, input_start, page_size, zkp_batch_size).await + }); + + offset += page_size; + } + + let results = futures::future::join_all(fetch_futures).await; + + // Initialize with first page data + let initial_root = first_page_data.initial_root; + let root_seq = first_page_data.root_seq; + let mut nodes_map: HashMap = HashMap::new(); + for (&idx, &hash) in first_page_data + .nodes + .iter() + .zip(first_page_data.node_hashes.iter()) + { + nodes_map.insert(idx, hash); + } + let mut output_queue = first_page_data.output_queue; + let mut input_queue = first_page_data.input_queue; + + // Merge remaining pages + for (page_idx, result) in results.into_iter().enumerate() { + let page = match result? { + Some(data) => data, + None => continue, + }; + + if page.initial_root != initial_root { + tracing::warn!( + "Page {} has different root ({:?} vs {:?}), stopping merge", + page_idx + 1, + &page.initial_root[..4], + &initial_root[..4] + ); + break; + } + + for (&idx, &hash) in page.nodes.iter().zip(page.node_hashes.iter()) { + nodes_map.entry(idx).or_insert(hash); + } + + if let Some(page_oq) = page.output_queue { + if let Some(ref mut oq) = output_queue { + oq.leaf_indices.extend(page_oq.leaf_indices); + oq.account_hashes.extend(page_oq.account_hashes); + oq.old_leaves.extend(page_oq.old_leaves); + oq.leaves_hash_chains.extend(page_oq.leaves_hash_chains); + } else { + output_queue = Some(page_oq); + } + } + + if let Some(page_iq) = page.input_queue { + if let Some(ref mut iq) = input_queue { + iq.leaf_indices.extend(page_iq.leaf_indices); + iq.account_hashes.extend(page_iq.account_hashes); + iq.current_leaves.extend(page_iq.current_leaves); + iq.tx_hashes.extend(page_iq.tx_hashes); 
+ iq.nullifiers.extend(page_iq.nullifiers); + iq.leaves_hash_chains.extend(page_iq.leaves_hash_chains); + } else { + input_queue = Some(page_iq); + } + } + } + + let mut nodes_vec: Vec<_> = nodes_map.into_iter().collect(); + nodes_vec.sort_by_key(|(idx, _)| *idx); + let (nodes, node_hashes): (Vec<_>, Vec<_>) = nodes_vec.into_iter().unzip(); + + tracing::debug!( + "Parallel fetch complete: {} nodes, output={}, input={}", + nodes.len(), + output_queue + .as_ref() + .map(|oq| oq.leaf_indices.len()) + .unwrap_or(0), + input_queue + .as_ref() + .map(|iq| iq.leaf_indices.len()) + .unwrap_or(0) + ); + + Ok(Some(StateQueueData { + nodes, + node_hashes, + initial_root, + root_seq, + output_queue, + input_queue, + })) +} + +pub async fn fetch_batches( + context: &BatchContext, + output_start_index: Option, + input_start_index: Option, + fetch_len: u64, + zkp_batch_size: u64, +) -> crate::Result> { + tracing::info!( + "fetch_batches: tree={}, output_start={:?}, input_start={:?}, fetch_len={}, zkp_batch_size={}", + context.merkle_tree, output_start_index, input_start_index, fetch_len, zkp_batch_size + ); + + let fetch_len_u16 = clamp_to_u16(fetch_len, "fetch_len"); + let zkp_batch_size_u16 = clamp_to_u16(zkp_batch_size, "zkp_batch_size"); + + let mut rpc = context.rpc_pool.get_connection().await?; + let indexer = rpc.indexer_mut()?; + let options = QueueElementsV2Options::default() + .with_output_queue(output_start_index, Some(fetch_len_u16)) + .with_output_queue_batch_size(Some(zkp_batch_size_u16)) + .with_input_queue(input_start_index, Some(fetch_len_u16)) + .with_input_queue_batch_size(Some(zkp_batch_size_u16)); + + let fetch_future = indexer.get_queue_elements(context.merkle_tree.to_bytes(), options, None); + + let res = match tokio::time::timeout(INDEXER_FETCH_TIMEOUT, fetch_future).await { + Ok(result) => result?, + Err(_) => { + tracing::warn!( + "fetch_batches timed out after {:?} for tree {}", + INDEXER_FETCH_TIMEOUT, + context.merkle_tree + ); + return 
Err(anyhow::anyhow!( + "Indexer fetch timed out after {:?} for state tree {}", + INDEXER_FETCH_TIMEOUT, + context.merkle_tree + )); + } + }; + + Ok(res.value.state_queue) +} + +pub async fn fetch_address_batches( + context: &BatchContext, + output_start_index: Option, + fetch_len: u64, + zkp_batch_size: u64, +) -> crate::Result> { + let fetch_len_u16 = clamp_to_u16(fetch_len, "fetch_len"); + let zkp_batch_size_u16 = clamp_to_u16(zkp_batch_size, "zkp_batch_size"); + + let mut rpc = context.rpc_pool.get_connection().await?; + let indexer = rpc.indexer_mut()?; + + let options = QueueElementsV2Options::default() + .with_address_queue(output_start_index, Some(fetch_len_u16)) + .with_address_queue_batch_size(Some(zkp_batch_size_u16)); + + tracing::debug!( + "fetch_address_batches: tree={}, start={:?}, len={}, zkp_batch_size={}", + context.merkle_tree, + output_start_index, + fetch_len_u16, + zkp_batch_size_u16 + ); + + let fetch_future = indexer.get_queue_elements(context.merkle_tree.to_bytes(), options, None); + + let res = match tokio::time::timeout(ADDRESS_INDEXER_FETCH_TIMEOUT, fetch_future).await { + Ok(result) => result?, + Err(_) => { + tracing::warn!( + "fetch_address_batches timed out after {:?} for tree {}", + ADDRESS_INDEXER_FETCH_TIMEOUT, + context.merkle_tree + ); + return Err(anyhow::anyhow!( + "Indexer fetch timed out after {:?} for address tree {}", + ADDRESS_INDEXER_FETCH_TIMEOUT, + context.merkle_tree + )); + } + }; + + if let Some(ref aq) = res.value.address_queue { + tracing::debug!( + "fetch_address_batches response: address_queue present = true, addresses={}, subtrees={}, leaves_hash_chains={}, start_index={}", + aq.addresses.len(), + aq.subtrees.len(), + aq.leaves_hash_chains.len(), + aq.start_index + ); + } else { + tracing::debug!("fetch_address_batches response: address_queue present = false"); + } + + Ok(res.value.address_queue) +} + +/// Streams address queue data by fetching pages in the background. 
+/// +/// The first page is fetched synchronously, then subsequent pages are fetched +/// in a background task. Consumers can access data as it becomes available +/// without waiting for the entire fetch to complete. +#[derive(Debug)] +pub struct StreamingAddressQueue { + /// The accumulated address queue data from all fetched pages. + pub data: Arc>, + + /// Number of elements currently available for processing. + /// Paired with `data_ready` condvar for signaling new data. + available_elements: Arc>, + + /// Signaled when new elements become available. + /// Paired with `available_elements` mutex. + data_ready: Arc, + + /// Whether the background fetch has completed (all pages fetched or error). + /// Paired with `fetch_complete_condvar` for signaling completion. + fetch_complete: Arc>, + + /// Signaled when background fetch completes. + /// Paired with `fetch_complete` mutex. + fetch_complete_condvar: Arc, + + /// Number of elements per ZKP batch, used for batch boundary calculations. + zkp_batch_size: usize, +} + +impl StreamingAddressQueue { + /// Waits until at least `batch_end` elements are available or fetch completes. + /// + /// Uses a polling loop to avoid race conditions between the available_elements + /// and fetch_complete mutexes. Returns the number of available elements. 
+ pub fn wait_for_batch(&self, batch_end: usize) -> usize { + const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(5); + + loop { + let available = *lock_recover( + &self.available_elements, + "streaming_address_queue.available_elements", + ); + if available >= batch_end { + return available; + } + + let complete = *lock_recover( + &self.fetch_complete, + "streaming_address_queue.fetch_complete", + ); + if complete { + return available; + } + + std::thread::sleep(POLL_INTERVAL); + } + } + + pub fn get_batch_data(&self, start: usize, end: usize) -> Option { + let available = self.wait_for_batch(end); + if start >= available { + return None; + } + let actual_end = end.min(available); + let data = lock_recover(&self.data, "streaming_address_queue.data"); + Some(BatchDataSlice { + addresses: data.addresses[start..actual_end].to_vec(), + low_element_values: data.low_element_values[start..actual_end].to_vec(), + low_element_next_values: data.low_element_next_values[start..actual_end].to_vec(), + low_element_indices: data.low_element_indices[start..actual_end].to_vec(), + low_element_next_indices: data.low_element_next_indices[start..actual_end].to_vec(), + }) + } + + pub fn into_data(self) -> AddressQueueData { + let mut complete = lock_recover( + &self.fetch_complete, + "streaming_address_queue.fetch_complete", + ); + while !*complete { + complete = match self.fetch_complete_condvar.wait_while(complete, |c| !*c) { + Ok(guard) => guard, + Err(poisoned) => { + tracing::warn!("Poisoned mutex while waiting (recovering): streaming_address_queue.fetch_complete"); + poisoned.into_inner() + } + }; + } + drop(complete); + match Arc::try_unwrap(self.data) { + Ok(mutex) => mutex.into_inner().unwrap_or_else(|poisoned| { + tracing::warn!("Poisoned mutex during into_data (recovering)"); + poisoned.into_inner() + }), + Err(arc) => lock_recover(arc.as_ref(), "streaming_address_queue.data_clone").clone(), + } + } + + pub fn initial_root(&self) -> [u8; 32] { + 
lock_recover(&self.data, "streaming_address_queue.data").initial_root + } + + pub fn start_index(&self) -> u64 { + lock_recover(&self.data, "streaming_address_queue.data").start_index + } + + pub fn subtrees(&self) -> Vec<[u8; 32]> { + lock_recover(&self.data, "streaming_address_queue.data") + .subtrees + .clone() + } + + pub fn root_seq(&self) -> u64 { + lock_recover(&self.data, "streaming_address_queue.data").root_seq + } + + pub fn available_batches(&self) -> usize { + debug_assert!(self.zkp_batch_size != 0, "zkp_batch_size must be non-zero"); + if self.zkp_batch_size == 0 { + tracing::error!("zkp_batch_size is zero, returning 0 batches to avoid panic"); + return 0; + } + let available = *lock_recover( + &self.available_elements, + "streaming_address_queue.available_elements", + ); + available / self.zkp_batch_size + } + + pub fn is_complete(&self) -> bool { + *lock_recover( + &self.fetch_complete, + "streaming_address_queue.fetch_complete", + ) + } +} + +#[derive(Debug, Clone)] +pub struct BatchDataSlice { + pub addresses: Vec<[u8; 32]>, + pub low_element_values: Vec<[u8; 32]>, + pub low_element_next_values: Vec<[u8; 32]>, + pub low_element_indices: Vec, + pub low_element_next_indices: Vec, +} + +pub async fn fetch_streaming_address_batches( + context: &BatchContext, + total_elements: u64, + zkp_batch_size: u64, +) -> crate::Result> { + if total_elements == 0 { + return Ok(None); + } + + let page_size_elements = ADDRESS_PAGE_SIZE_BATCHES * zkp_batch_size; + let num_pages = total_elements.div_ceil(page_size_elements) as usize; + + tracing::info!( + "address fetch: {} elements ({} batches) in {} pages of {} batches each", + total_elements, + total_elements / zkp_batch_size, + num_pages, + ADDRESS_PAGE_SIZE_BATCHES + ); + + let first_page_size = page_size_elements.min(total_elements); + let first_page = + match fetch_address_batches(context, None, first_page_size, zkp_batch_size).await? 
{ + Some(data) if !data.addresses.is_empty() => data, + _ => return Ok(None), + }; + + let initial_elements = first_page.addresses.len(); + let first_page_requested = first_page_size as usize; + + let queue_exhausted = initial_elements < first_page_requested; + + tracing::info!( + "First page fetched: {} addresses ({} batches ready), root={:?}[..4], queue_exhausted={}", + initial_elements, + initial_elements / zkp_batch_size as usize, + &first_page.initial_root[..4], + queue_exhausted + ); + + let streaming = StreamingAddressQueue { + data: Arc::new(Mutex::new(first_page)), + available_elements: Arc::new(Mutex::new(initial_elements)), + data_ready: Arc::new(Condvar::new()), + fetch_complete: Arc::new(Mutex::new(num_pages == 1 || queue_exhausted)), + fetch_complete_condvar: Arc::new(Condvar::new()), + zkp_batch_size: zkp_batch_size as usize, + }; + + if num_pages == 1 || queue_exhausted { + return Ok(Some(streaming)); + } + + let data = Arc::clone(&streaming.data); + let available = Arc::clone(&streaming.available_elements); + let ready = Arc::clone(&streaming.data_ready); + let complete = Arc::clone(&streaming.fetch_complete); + let complete_condvar = Arc::clone(&streaming.fetch_complete_condvar); + let ctx = context.clone(); + let initial_root = streaming.initial_root(); + + // Get the start_index from the first page to calculate offsets for subsequent pages + let first_page_start_index = streaming.start_index(); + + tokio::spawn(async move { + let mut offset = first_page_size; + + for page_idx in 1..num_pages { + let page_size = (total_elements - offset).min(page_size_elements); + // Use absolute index: first page's start_index + relative offset + let absolute_start = Some(first_page_start_index + offset); + + tracing::debug!( + "Fetching address page {}/{}: absolute_start={:?}, size={}", + page_idx + 1, + num_pages, + absolute_start, + page_size + ); + + match fetch_address_batches(&ctx, absolute_start, page_size, zkp_batch_size).await { + Ok(Some(page)) => { + 
if page.initial_root != initial_root { + tracing::warn!( + "Address page {} has different root ({:?} vs {:?}), stopping fetch", + page_idx + 1, + &page.initial_root[..4], + &initial_root[..4] + ); + break; + } + + let page_elements = page.addresses.len(); + let page_requested = page_size as usize; + + { + let mut data_guard = + lock_recover(data.as_ref(), "streaming_address_queue.data"); + data_guard.addresses.extend(page.addresses); + data_guard + .low_element_values + .extend(page.low_element_values); + data_guard + .low_element_next_values + .extend(page.low_element_next_values); + data_guard + .low_element_indices + .extend(page.low_element_indices); + data_guard + .low_element_next_indices + .extend(page.low_element_next_indices); + data_guard + .leaves_hash_chains + .extend(page.leaves_hash_chains); + let mut seen: HashSet = data_guard.nodes.iter().copied().collect(); + for (&idx, &hash) in page.nodes.iter().zip(page.node_hashes.iter()) { + if seen.insert(idx) { + data_guard.nodes.push(idx); + data_guard.node_hashes.push(hash); + } + } + } + + { + let mut avail = lock_recover( + available.as_ref(), + "streaming_address_queue.available_elements", + ); + *avail += page_elements; + tracing::debug!( + "Page {} fetched: {} elements, total available: {} ({} batches)", + page_idx + 1, + page_elements, + *avail, + *avail / zkp_batch_size as usize + ); + } + ready.notify_all(); + + if page_elements < page_requested { + tracing::debug!( + "Page {} returned fewer elements than requested ({} < {}), queue exhausted", + page_idx + 1, page_elements, page_requested + ); + break; + } + } + Ok(None) => { + tracing::debug!("Page {} returned empty, stopping fetch", page_idx + 1); + break; + } + Err(e) => { + tracing::warn!("Error fetching page {}: {}", page_idx + 1, e); + break; + } + } + + offset += page_size; + } + + { + let mut complete_guard = + lock_recover(complete.as_ref(), "streaming_address_queue.fetch_complete"); + *complete_guard = true; + } + ready.notify_all(); + 
complete_condvar.notify_all(); + tracing::debug!("Background address fetch complete"); + }); + + Ok(Some(streaming)) +} diff --git a/forester/src/processor/v2/mod.rs b/forester/src/processor/v2/mod.rs index 13ed8ed01c..ef1fdf10b1 100644 --- a/forester/src/processor/v2/mod.rs +++ b/forester/src/processor/v2/mod.rs @@ -1,39 +1,23 @@ -mod address; -mod common; -pub mod state; +mod batch_job_builder; +pub mod common; +pub mod errors; +mod helpers; +mod processor; +pub mod proof_cache; +mod proof_worker; +mod root_guard; +pub mod strategy; +mod tx_sender; -use common::BatchProcessor; -use light_client::rpc::Rpc; -use tracing::{instrument, trace}; +pub use common::{BatchContext, ProverConfig}; +pub use processor::QueueProcessor; +pub use proof_cache::{CachedProof, SharedProofCache}; +pub use tx_sender::{BatchInstruction, ProofTimings, TxSenderResult}; -use crate::Result; +use crate::epoch_manager::ProcessingMetrics; -#[instrument( - level = "debug", - fields( - epoch = context.epoch, - tree = %context.merkle_tree, - tree_type = ?tree_type - ), - skip(context) -)] -pub async fn process_batched_operations( - context: BatchContext, - tree_type: TreeType, -) -> Result { - trace!("process_batched_operations"); - match tree_type { - TreeType::AddressV2 => { - let processor = BatchProcessor::new(context, tree_type); - processor.process().await - } - TreeType::StateV2 => { - trace!("StateV2 processing should be handled through StateSupervisor actor"); - Ok(0) - } - _ => Ok(0), - } +#[derive(Debug, Clone, Default)] +pub struct ProcessingResult { + pub items_processed: usize, + pub metrics: ProcessingMetrics, } - -pub use common::{BatchContext, ProverConfig}; -use light_compressed_account::TreeType; diff --git a/forester/src/processor/v2/processor.rs b/forester/src/processor/v2/processor.rs new file mode 100644 index 0000000000..5f26f4495e --- /dev/null +++ b/forester/src/processor/v2/processor.rs @@ -0,0 +1,671 @@ +use std::{ + sync::{atomic::Ordering, Arc}, + time::{Duration, 
Instant}, +}; + +use anyhow::anyhow; +use forester_utils::{forester_epoch::EpochPhases, utils::wait_for_indexer}; +use light_client::rpc::Rpc; +use light_compressed_account::QueueType; +use solana_sdk::pubkey::Pubkey; +use tokio::sync::mpsc; +use tracing::{debug, info, warn}; + +use crate::{ + epoch_manager::{CircuitMetrics, ProcessingMetrics}, + processor::v2::{ + batch_job_builder::BatchJobBuilder, + common::WorkerPool, + errors::V2Error, + proof_cache::SharedProofCache, + proof_worker::{spawn_proof_workers, ProofJob, ProofJobResult}, + root_guard::{reconcile_roots, RootReconcileDecision}, + strategy::{CircuitType, QueueData, TreeStrategy}, + tx_sender::{BatchInstruction, ProofTimings, TxSender}, + BatchContext, ProcessingResult, + }, +}; + +const MAX_BATCHES_PER_TREE: usize = 20; + +#[derive(Debug, Default, Clone)] +struct BatchTimings { + append_circuit_inputs: Duration, + nullify_circuit_inputs: Duration, + address_append_circuit_inputs: Duration, + append_count: usize, + nullify_count: usize, + address_append_count: usize, +} + +impl BatchTimings { + fn add_timing(&mut self, circuit_type: CircuitType, duration: Duration) { + match circuit_type { + CircuitType::Append => { + self.append_circuit_inputs += duration; + self.append_count += 1; + } + CircuitType::Nullify => { + self.nullify_circuit_inputs += duration; + self.nullify_count += 1; + } + CircuitType::AddressAppend => { + self.address_append_circuit_inputs += duration; + self.address_append_count += 1; + } + } + } +} + +struct CachedQueueState { + staging_tree: T, + batches_processed: usize, + total_batches: usize, +} + +pub struct QueueProcessor> { + context: BatchContext, + strategy: S, + current_root: [u8; 32], + zkp_batch_size: u64, + seq: u64, + worker_pool: Option, + cached_state: Option>, + proof_cache: Option>, +} + +impl> std::fmt::Debug for QueueProcessor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QueueProcessor") + .field("merkle_tree", 
&self.context.merkle_tree) + .field("epoch", &self.context.epoch) + .field("zkp_batch_size", &self.zkp_batch_size) + .finish() + } +} + +impl + 'static> QueueProcessor +where + S::StagingTree: BatchJobBuilder, +{ + pub async fn new(context: BatchContext, strategy: S) -> crate::Result { + let zkp_batch_size = strategy.fetch_zkp_batch_size(&context).await?; + let current_root = strategy.fetch_onchain_root(&context).await?; + info!( + "Initializing {} processor for tree {} with on-chain root {:?}[..4]", + strategy.name(), + context.merkle_tree, + ¤t_root[..4] + ); + Ok(Self { + context, + strategy, + current_root, + zkp_batch_size, + seq: 0, + worker_pool: None, + cached_state: None, + proof_cache: None, + }) + } + + pub fn set_proof_cache(&mut self, cache: Arc) { + self.proof_cache = Some(cache); + } + + pub async fn process(&mut self) -> crate::Result { + let queue_size = self.zkp_batch_size * MAX_BATCHES_PER_TREE as u64; + self.process_queue_update(queue_size).await + } + + pub async fn process_queue_update( + &mut self, + queue_size: u64, + ) -> crate::Result { + if queue_size < self.zkp_batch_size { + return Ok(ProcessingResult::default()); + } + + if self.worker_pool.is_none() { + let job_tx = spawn_proof_workers(&self.context.prover_config); + self.worker_pool = Some(WorkerPool { job_tx }); + } + + if let Some(cached) = self.cached_state.take() { + let actual_available = self + .strategy + .available_batches(&cached.staging_tree, self.zkp_batch_size); + let total_batches = if actual_available == usize::MAX { + cached.total_batches + } else { + actual_available + }; + + let remaining = total_batches.saturating_sub(cached.batches_processed); + if remaining > 0 { + info!( + "Using cached state: {} remaining batches (processed {}/{}, actual available: {})", + remaining, cached.batches_processed, total_batches, + if actual_available == usize::MAX { "max".to_string() } else { actual_available.to_string() } + ); + + let batches_to_process = 
remaining.min(MAX_BATCHES_PER_TREE); + let queue_data = QueueData { + staging_tree: cached.staging_tree, + initial_root: self.current_root, + num_batches: total_batches, + }; + + return self + .process_batches( + queue_data, + cached.batches_processed, + batches_to_process, + total_batches, + ) + .await; + } + } + + let available_batches = (queue_size / self.zkp_batch_size) as usize; + let fetch_batches = available_batches.min(MAX_BATCHES_PER_TREE); + + if available_batches > MAX_BATCHES_PER_TREE { + debug!( + "Queue has {} batches available, fetching {} for {} iterations", + available_batches, + fetch_batches, + available_batches.div_ceil(fetch_batches) + ); + } + + { + let rpc = self.context.rpc_pool.get_connection().await?; + if let Err(e) = wait_for_indexer(&*rpc).await { + warn!("wait_for_indexer error (proceeding anyway): {}", e); + } + } + + let queue_data = match self + .strategy + .fetch_queue_data(&self.context, fetch_batches, self.zkp_batch_size) + .await? + { + Some(data) => data, + None => return Ok(ProcessingResult::default()), + }; + + if self.current_root == [0u8; 32] || queue_data.initial_root == self.current_root { + let total_batches = queue_data.num_batches; + let process_now = total_batches.min(MAX_BATCHES_PER_TREE); + return self + .process_batches(queue_data, 0, process_now, total_batches) + .await; + } + + let onchain_root = self.strategy.fetch_onchain_root(&self.context).await?; + match reconcile_roots(self.current_root, queue_data.initial_root, onchain_root) { + RootReconcileDecision::Proceed => { + let total_batches = queue_data.num_batches; + let process_now = total_batches.min(MAX_BATCHES_PER_TREE); + self.process_batches(queue_data, 0, process_now, total_batches) + .await + } + RootReconcileDecision::WaitForIndexer => { + debug!( + "Indexer root {:?}[..4] doesn't match expected {:?}[..4], on-chain confirms we're ahead. 
Waiting for next slot.", + &queue_data.initial_root[..4], + &self.current_root[..4] + ); + Ok(ProcessingResult::default()) + } + RootReconcileDecision::ResetToOnchainAndProceed(root) => { + debug!( + "Resetting to on-chain root {:?}[..4] (was expecting {:?}[..4])", + &root[..4], + &self.current_root[..4] + ); + self.current_root = root; + self.cached_state = None; + let total_batches = queue_data.num_batches; + let process_now = total_batches.min(MAX_BATCHES_PER_TREE); + self.process_batches(queue_data, 0, process_now, total_batches) + .await + } + RootReconcileDecision::ResetToOnchainAndStop(root) => { + warn!( + "Root divergence: expected {:?}[..4], indexer {:?}[..4], on-chain {:?}[..4]. Resetting.", + &self.current_root[..4], + &queue_data.initial_root[..4], + &root[..4] + ); + self.current_root = root; + self.cached_state = None; + Ok(ProcessingResult::default()) + } + } + } + + pub async fn clear_cache(&mut self) { + self.cached_state = None; + if let Some(proof_cache) = &self.proof_cache { + proof_cache.clear().await; + } + } + + pub fn update_eligibility(&mut self, end_slot: u64) { + self.context + .forester_eligibility_end_slot + .store(end_slot, Ordering::Relaxed); + } + + pub fn update_epoch(&mut self, new_epoch: u64, new_phases: EpochPhases) { + self.context.epoch = new_epoch; + self.context.epoch_phases = new_phases; + } + + pub fn merkle_tree(&self) -> &Pubkey { + &self.context.merkle_tree + } + + pub fn epoch(&self) -> u64 { + self.context.epoch + } + + pub fn zkp_batch_size(&self) -> u64 { + self.zkp_batch_size + } + + async fn process_batches( + &mut self, + queue_data: QueueData, + batch_offset: usize, + batches_to_process: usize, + total_batches: usize, + ) -> crate::Result { + self.current_root = queue_data.initial_root; + let num_workers = self.context.num_proof_workers.max(1); + let (proof_tx, proof_rx) = mpsc::channel(num_workers * 2); + + let tx_sender_handle = TxSender::spawn( + self.context.clone(), + proof_rx, + self.zkp_batch_size, + 
self.current_root, + self.proof_cache.clone(), + ); + let job_tx = self + .worker_pool + .as_ref() + .ok_or_else(|| anyhow!("Worker pool not initialized"))? + .job_tx + .clone(); + + let (jobs_sent, timings, staging_tree) = self + .enqueue_jobs( + queue_data, + batch_offset, + batches_to_process, + job_tx, + proof_tx.clone(), + ) + .await?; + + let total_processed = batch_offset + batches_to_process; + let remaining_batches = total_batches.saturating_sub(total_processed); + if remaining_batches > 0 { + debug!( + "Caching {} remaining batches for optimistic continuation (processed {}/{})", + remaining_batches, total_processed, total_batches + ); + self.cached_state = Some(CachedQueueState { + staging_tree, + batches_processed: total_processed, + total_batches, + }); + } else { + self.cached_state = None; + } + + drop(proof_tx); + + let tx_result = tx_sender_handle + .await + .map_err(|e| anyhow!("Tx sender join error: {}", e)) + .and_then(|res| res); + + if let Err(ref e) = tx_result { + if let Some(v2) = e.downcast_ref::() { + if v2.is_constraint() { + warn!( + "Tx sender constraint error for tree {}: {}", + self.context.merkle_tree, e + ); + return Err(tx_result.unwrap_err()); + } + } + } + + let (tx_processed, proof_timings, tx_sending_duration) = match &tx_result { + Ok(result) => ( + result.items_processed, + result.proof_timings.clone(), + result.tx_sending_duration, + ), + Err(e) => { + warn!( + "Tx sender error for tree {}: {}", + self.context.merkle_tree, e + ); + (0, Default::default(), Duration::ZERO) + } + }; + + if tx_processed < jobs_sent * self.zkp_batch_size as usize { + debug!( + "Processed {} items but expected {}, some proofs may have failed", + tx_processed, + jobs_sent * self.zkp_batch_size as usize + ); + } + + let mut metrics = ProcessingMetrics::default(); + + if timings.append_count > 0 { + metrics.append = CircuitMetrics { + circuit_inputs_duration: timings.append_circuit_inputs, + proof_generation_duration: 
proof_timings.append_proof_duration(), + round_trip_duration: proof_timings.append_round_trip_duration(), + }; + } + if timings.nullify_count > 0 { + metrics.nullify = CircuitMetrics { + circuit_inputs_duration: timings.nullify_circuit_inputs, + proof_generation_duration: proof_timings.nullify_proof_duration(), + round_trip_duration: proof_timings.nullify_round_trip_duration(), + }; + } + if timings.address_append_count > 0 { + metrics.address_append = CircuitMetrics { + circuit_inputs_duration: timings.address_append_circuit_inputs, + proof_generation_duration: proof_timings.address_append_proof_duration(), + round_trip_duration: proof_timings.address_append_round_trip_duration(), + }; + } + metrics.tx_sending_duration = tx_sending_duration; + + if let Err(e) = tx_result { + warn!( + "Returning partial metrics despite error for tree {}: {}", + self.context.merkle_tree, e + ); + } + + Ok(ProcessingResult { + items_processed: tx_processed, + metrics, + }) + } + + async fn enqueue_jobs( + &mut self, + queue_data: QueueData, + batch_offset: usize, + num_batches: usize, + job_tx: async_channel::Sender, + result_tx: mpsc::Sender, + ) -> crate::Result<(usize, BatchTimings, S::StagingTree)> + where + S::StagingTree: 'static, + { + let zkp_batch_size = self.zkp_batch_size; + let strategy = self.strategy.clone(); + let initial_seq = self.seq; + let epoch = self.context.epoch; + let tree = self.context.merkle_tree.to_string(); + + let result = tokio::task::spawn_blocking(move || { + let mut staging_tree = queue_data.staging_tree; + let mut jobs_sent = 0; + let mut final_root = queue_data.initial_root; + let mut current_seq = initial_seq; + let mut timings = BatchTimings::default(); + + let mut skipped_batches = 0usize; + for i in 0..num_batches { + let batch_idx = batch_offset + i; + + let circuit_type = strategy.circuit_type_for_batch(&staging_tree, batch_idx); + + let circuit_start = Instant::now(); + let proof_result = strategy.build_proof_job( + &mut staging_tree, + 
batch_idx, + zkp_batch_size, + epoch, + &tree, + )?; + let circuit_duration = circuit_start.elapsed(); + + let (inputs, new_root) = match proof_result { + Some(result) => result, + None => { + skipped_batches += 1; + continue; + } + }; + + timings.add_timing(circuit_type, circuit_duration); + + final_root = new_root; + let job = ProofJob { + seq: current_seq, + inputs, + result_tx: result_tx.clone(), + tree_id: tree.clone(), + }; + current_seq += 1; + + job_tx + .send_blocking(job) + .map_err(|e| anyhow::anyhow!("Failed to send job: {}", e))?; + jobs_sent += 1; + } + + if skipped_batches > 0 { + tracing::debug!( + "Skipped {}/{} batches due to overlap", + skipped_batches, + num_batches + ); + } + + Ok::<_, anyhow::Error>((jobs_sent, final_root, current_seq, timings, staging_tree)) + }) + .await + .map_err(|e| anyhow::anyhow!("Blocking task panicked: {}", e))??; + + let (jobs_sent, final_root, final_seq, timings, staging_tree) = result; + + self.current_root = final_root; + self.seq = final_seq; + + Ok((jobs_sent, timings, staging_tree)) + } + + pub async fn prewarm_proofs( + &mut self, + cache: Arc, + queue_size: u64, + ) -> crate::Result { + if queue_size < self.zkp_batch_size { + return Ok(ProcessingResult::default()); + } + + let max_batches = ((queue_size / self.zkp_batch_size) as usize).min(MAX_BATCHES_PER_TREE); + + if self.worker_pool.is_none() { + let job_tx = spawn_proof_workers(&self.context.prover_config); + self.worker_pool = Some(WorkerPool { job_tx }); + } + + let queue_data = match self + .strategy + .fetch_queue_data(&self.context, max_batches, self.zkp_batch_size) + .await? 
+ { + Some(data) => data, + None => return Ok(ProcessingResult::default()), + }; + + self.prewarm_batches(cache, queue_data).await + } + + pub async fn prewarm_from_indexer( + &mut self, + cache: Arc, + _queue_type: QueueType, + max_batches: usize, + ) -> crate::Result { + if max_batches == 0 { + return Ok(ProcessingResult::default()); + } + + let max_batches = max_batches.min(MAX_BATCHES_PER_TREE); + + if self.worker_pool.is_none() { + let job_tx = spawn_proof_workers(&self.context.prover_config); + self.worker_pool = Some(WorkerPool { job_tx }); + } + + let queue_data = match self + .strategy + .fetch_queue_data(&self.context, max_batches, self.zkp_batch_size) + .await? + { + Some(data) => data, + None => return Ok(ProcessingResult::default()), + }; + + self.prewarm_batches(cache, queue_data).await + } + + async fn prewarm_batches( + &mut self, + cache: Arc, + queue_data: QueueData, + ) -> crate::Result { + let initial_root = queue_data.initial_root; + self.current_root = initial_root; + let num_batches = queue_data.num_batches; + let num_workers = self.context.num_proof_workers.max(1); + + cache.start_warming(initial_root).await; + + let (proof_tx, mut proof_rx) = mpsc::channel(num_workers * 2); + + let job_tx = self + .worker_pool + .as_ref() + .ok_or_else(|| anyhow!("Worker pool not initialized"))? 
+ .job_tx + .clone(); + + info!( + "Pre-warming {} proofs for tree {} with root {:?}", + num_batches, + self.context.merkle_tree, + &initial_root[..4] + ); + + let (jobs_sent, timings, _staging_tree) = self + .enqueue_jobs(queue_data, 0, num_batches, job_tx, proof_tx.clone()) + .await?; + + drop(proof_tx); + + let mut proofs_cached = 0; + let mut proof_timings = ProofTimings::default(); + + while let Some(result) = proof_rx.recv().await { + match result.result { + Ok(instruction) => { + match &instruction { + BatchInstruction::Append(_) => { + proof_timings.append_proof_ms += result.proof_duration_ms; + proof_timings.append_round_trip_ms += result.round_trip_ms; + } + BatchInstruction::Nullify(_) => { + proof_timings.nullify_proof_ms += result.proof_duration_ms; + proof_timings.nullify_round_trip_ms += result.round_trip_ms; + } + BatchInstruction::AddressAppend(_) => { + proof_timings.address_append_proof_ms += result.proof_duration_ms; + proof_timings.address_append_round_trip_ms += result.round_trip_ms; + } + } + + cache + .add_proof(result.seq, result.old_root, result.new_root, instruction) + .await; + proofs_cached += 1; + } + Err(e) => { + warn!( + "Proof generation failed during pre-warm for seq={}: {}", + result.seq, e + ); + } + } + } + + cache.finish_warming().await; + + if proofs_cached < jobs_sent { + warn!( + "Pre-warmed {} proofs but expected {} for tree {}", + proofs_cached, jobs_sent, self.context.merkle_tree + ); + } else { + info!( + "Pre-warmed {} proofs for tree {} (zkp_batch_size={}, items={})", + proofs_cached, + self.context.merkle_tree, + self.zkp_batch_size, + proofs_cached * self.zkp_batch_size as usize + ); + } + + let mut metrics = ProcessingMetrics::default(); + if timings.append_count > 0 { + metrics.append = CircuitMetrics { + circuit_inputs_duration: timings.append_circuit_inputs, + proof_generation_duration: proof_timings.append_proof_duration(), + round_trip_duration: proof_timings.append_round_trip_duration(), + }; + } + if 
timings.nullify_count > 0 { + metrics.nullify = CircuitMetrics { + circuit_inputs_duration: timings.nullify_circuit_inputs, + proof_generation_duration: proof_timings.nullify_proof_duration(), + round_trip_duration: proof_timings.nullify_round_trip_duration(), + }; + } + if timings.address_append_count > 0 { + metrics.address_append = CircuitMetrics { + circuit_inputs_duration: timings.address_append_circuit_inputs, + proof_generation_duration: proof_timings.address_append_proof_duration(), + round_trip_duration: proof_timings.address_append_round_trip_duration(), + }; + } + + Ok(ProcessingResult { + items_processed: proofs_cached * self.zkp_batch_size as usize, + metrics, + }) + } + + pub fn current_root(&self) -> &[u8; 32] { + &self.current_root + } +} diff --git a/forester/src/processor/v2/proof_cache.rs b/forester/src/processor/v2/proof_cache.rs new file mode 100644 index 0000000000..123b4acf98 --- /dev/null +++ b/forester/src/processor/v2/proof_cache.rs @@ -0,0 +1,294 @@ +use std::collections::{BTreeMap, VecDeque}; + +use solana_sdk::pubkey::Pubkey; +use tokio::sync::Mutex; +use tracing::{debug, info, warn}; + +use super::tx_sender::BatchInstruction; + +const DEFAULT_MAX_CACHED_PROOFS: usize = 256; + +#[derive(Debug, Clone)] +pub struct CachedProof { + pub seq: u64, + pub old_root: [u8; 32], + pub new_root: [u8; 32], + pub instruction: BatchInstruction, + /// Number of ZKP batch instructions represented by this proof. 
+ pub items: usize, +} + +#[derive(Debug)] +pub struct ProofCache { + tree: Pubkey, + base_root: [u8; 32], + proofs: VecDeque, + warming_proofs: BTreeMap, + is_warming: bool, + max_proofs: usize, +} + +impl ProofCache { + pub fn new(tree: Pubkey) -> Self { + Self { + tree, + base_root: [0u8; 32], + proofs: VecDeque::new(), + warming_proofs: BTreeMap::new(), + is_warming: false, + max_proofs: DEFAULT_MAX_CACHED_PROOFS, + } + } + + pub fn start_warming(&mut self, base_root: [u8; 32]) { + debug!( + "Starting cache warm-up for tree {} with root {:?}", + self.tree, + &base_root[..4] + ); + self.base_root = base_root; + self.proofs.clear(); + self.warming_proofs.clear(); + self.is_warming = true; + } + + pub fn add_proof( + &mut self, + seq: u64, + old_root: [u8; 32], + new_root: [u8; 32], + instruction: BatchInstruction, + ) { + if !self.is_warming { + warn!("Attempted to add proof to cache that is not warming"); + return; + } + if self.warming_proofs.contains_key(&seq) { + warn!( + "Duplicate cached proof seq={} for tree {}, ignoring", + seq, self.tree + ); + return; + } + + let items = instruction.items_count(); + self.warming_proofs.insert( + seq, + CachedProof { + seq, + old_root, + new_root, + instruction, + items, + }, + ); + + while self.warming_proofs.len() > self.max_proofs { + let Some((&last_seq, _)) = self.warming_proofs.last_key_value() else { + break; + }; + self.warming_proofs.remove(&last_seq); + warn!( + "Proof cache warm-up limit reached for tree {} (max={}), dropping newest seq={}", + self.tree, self.max_proofs, last_seq + ); + } + debug!( + "Cached proof seq={} for tree {} (total cached: {})", + seq, + self.tree, + self.warming_proofs.len() + ); + } + + pub fn finish_warming(&mut self) { + self.is_warming = false; + + if self.warming_proofs.is_empty() { + self.proofs.clear(); + info!( + "Cache warm-up complete for tree {}: 0 proofs cached with root {:?}", + self.tree, + &self.base_root[..4] + ); + return; + } + + if let Some(first) = 
self.warming_proofs.values().next() { + if self.base_root != [0u8; 32] && self.base_root != first.old_root { + warn!( + "First cached proof root mismatch for tree {}: base_root={:?}, proof.old_root={:?} (seq={})", + self.tree, + &self.base_root[..4], + &first.old_root[..4], + first.seq + ); + } + } + + self.proofs = self.warming_proofs.values().cloned().collect(); + self.warming_proofs.clear(); + + info!( + "Cache warm-up complete for tree {}: {} proofs cached with root {:?}", + self.tree, + self.proofs.len(), + &self.base_root[..4] + ); + } + + pub fn take_if_valid(&mut self, current_root: &[u8; 32]) -> Option> { + if self.proofs.is_empty() || self.is_warming { + return None; + } + + let mut skipped = 0; + while let Some(proof) = self.proofs.front() { + if proof.old_root == *current_root { + break; + } + if proof.new_root == *current_root { + self.proofs.pop_front(); + skipped += 1; + continue; + } + self.proofs.pop_front(); + skipped += 1; + } + + if skipped > 0 { + debug!( + "Skipped {} stale cached proofs for tree {} (on-chain already advanced)", + skipped, self.tree + ); + } + + if self.proofs.is_empty() { + debug!( + "Cache empty after skipping stale proofs for tree {} (current_root {:?})", + self.tree, + ¤t_root[..4] + ); + return None; + } + + let mut expected = *current_root; + let mut taken: Vec = Vec::new(); + + while let Some(proof) = self.proofs.pop_front() { + if proof.old_root != expected { + warn!( + "Cache chain broken for tree {} at seq {}: expected root {:?}, got {:?}. 
Dropping remaining {} proofs.", + self.tree, + proof.seq, + &expected[..4], + &proof.old_root[..4], + self.proofs.len() + ); + self.proofs.clear(); + break; + } + expected = proof.new_root; + taken.push(proof); + } + + if taken.is_empty() { + return None; + } + + info!( + "Using {} cached proofs for tree {} starting at root {:?} ending at {:?}{}", + taken.len(), + self.tree, + ¤t_root[..4], + &expected[..4], + if skipped > 0 { + format!(" (skipped {} stale)", skipped) + } else { + String::new() + } + ); + Some(taken) + } + + pub fn len(&self) -> usize { + self.proofs.len() + } + + pub fn is_empty(&self) -> bool { + self.proofs.is_empty() + } + + pub fn is_warming(&self) -> bool { + self.is_warming + } + + pub fn base_root(&self) -> &[u8; 32] { + &self.base_root + } + + pub fn clear(&mut self) { + self.proofs.clear(); + self.warming_proofs.clear(); + self.is_warming = false; + } +} + +pub struct SharedProofCache { + inner: Mutex, +} + +impl std::fmt::Debug for SharedProofCache { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SharedProofCache").finish_non_exhaustive() + } +} + +impl SharedProofCache { + pub fn new(tree: Pubkey) -> Self { + Self { + inner: Mutex::new(ProofCache::new(tree)), + } + } + + pub async fn start_warming(&self, base_root: [u8; 32]) { + self.inner.lock().await.start_warming(base_root); + } + + pub async fn add_proof( + &self, + seq: u64, + old_root: [u8; 32], + new_root: [u8; 32], + instruction: BatchInstruction, + ) { + self.inner + .lock() + .await + .add_proof(seq, old_root, new_root, instruction); + } + + pub async fn finish_warming(&self) { + self.inner.lock().await.finish_warming(); + } + + pub async fn take_if_valid(&self, current_root: &[u8; 32]) -> Option> { + self.inner.lock().await.take_if_valid(current_root) + } + + pub async fn is_warming(&self) -> bool { + self.inner.lock().await.is_warming() + } + + pub async fn len(&self) -> usize { + self.inner.lock().await.len() + } + + pub async fn 
is_empty(&self) -> bool { + self.inner.lock().await.is_empty() + } + + pub async fn clear(&self) { + self.inner.lock().await.clear(); + } +} diff --git a/forester/src/processor/v2/proof_worker.rs b/forester/src/processor/v2/proof_worker.rs new file mode 100644 index 0000000000..69f8bd8c29 --- /dev/null +++ b/forester/src/processor/v2/proof_worker.rs @@ -0,0 +1,417 @@ +use std::{sync::Arc, time::Duration}; + +use async_channel::Receiver; +use light_batched_merkle_tree::merkle_tree::{ + InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, +}; +use light_prover_client::{ + errors::ProverClientError, + proof::ProofResult, + proof_client::{ProofClient, SubmitProofResult}, + proof_types::{ + batch_address_append::BatchAddressAppendInputs, + batch_append::{BatchAppendInputsJson, BatchAppendsCircuitInputs}, + batch_update::BatchUpdateCircuitInputs, + }, +}; +use tokio::sync::mpsc; +use tracing::{debug, error, warn}; + +use crate::processor::v2::{tx_sender::BatchInstruction, ProverConfig}; + +#[derive(Debug, Clone)] +pub enum ProofInput { + Append(BatchAppendsCircuitInputs), + Nullify(BatchUpdateCircuitInputs), + AddressAppend(BatchAddressAppendInputs), +} + +impl ProofInput { + fn circuit_type(&self) -> &'static str { + match self { + ProofInput::Append(_) => "append", + ProofInput::Nullify(_) => "update", + ProofInput::AddressAppend(_) => "address_append", + } + } + + fn to_json(&self, tree_id: &str, batch_index: u64) -> String { + match self { + ProofInput::Append(inputs) => BatchAppendInputsJson::from_inputs(inputs) + .with_tree_id(tree_id.to_string()) + .with_batch_index(batch_index) + .to_string(), + ProofInput::Nullify(inputs) => { + use light_prover_client::proof_types::batch_update::BatchUpdateProofInputsJson; + BatchUpdateProofInputsJson::from_update_inputs(inputs) + .with_tree_id(tree_id.to_string()) + .with_batch_index(batch_index) + .to_string() + } + ProofInput::AddressAppend(inputs) => { + use 
light_prover_client::proof_types::batch_address_append::BatchAddressAppendInputsJson; + BatchAddressAppendInputsJson::from_inputs(inputs) + .with_tree_id(tree_id.to_string()) + .with_batch_index(batch_index) + .to_string() + } + } + } + + fn new_root_bytes(&self) -> crate::Result<[u8; 32]> { + match self { + ProofInput::Append(inputs) => { + let biguint = inputs.new_root.to_biguint().ok_or_else(|| { + anyhow::anyhow!("Failed to convert append new_root to biguint") + })?; + light_hasher::bigint::bigint_to_be_bytes_array::<32>(&biguint).map_err(Into::into) + } + ProofInput::Nullify(inputs) => { + let biguint = inputs.new_root.to_biguint().ok_or_else(|| { + anyhow::anyhow!("Failed to convert nullify new_root to biguint") + })?; + light_hasher::bigint::bigint_to_be_bytes_array::<32>(&biguint).map_err(Into::into) + } + ProofInput::AddressAppend(inputs) => { + light_hasher::bigint::bigint_to_be_bytes_array::<32>(&inputs.new_root) + .map_err(Into::into) + } + } + } + + fn old_root_bytes(&self) -> crate::Result<[u8; 32]> { + match self { + ProofInput::Append(inputs) => { + let biguint = inputs.old_root.to_biguint().ok_or_else(|| { + anyhow::anyhow!("Failed to convert append old_root to biguint") + })?; + light_hasher::bigint::bigint_to_be_bytes_array::<32>(&biguint).map_err(Into::into) + } + ProofInput::Nullify(inputs) => { + let biguint = inputs.old_root.to_biguint().ok_or_else(|| { + anyhow::anyhow!("Failed to convert nullify old_root to biguint") + })?; + light_hasher::bigint::bigint_to_be_bytes_array::<32>(&biguint).map_err(Into::into) + } + ProofInput::AddressAppend(inputs) => { + light_hasher::bigint::bigint_to_be_bytes_array::<32>(&inputs.old_root) + .map_err(Into::into) + } + } + } +} + +pub struct ProofJob { + pub(crate) seq: u64, + pub(crate) inputs: ProofInput, + pub(crate) result_tx: mpsc::Sender, + /// Tree pubkey for fair queuing - used to prevent starvation when multiple trees have proofs pending + pub(crate) tree_id: String, +} + +#[derive(Debug)] +pub 
struct ProofJobResult { + pub(crate) seq: u64, + pub(crate) result: Result, + pub(crate) old_root: [u8; 32], + pub(crate) new_root: [u8; 32], + /// Pure proof generation time in milliseconds (from prover server). + pub(crate) proof_duration_ms: u64, + /// Total round-trip time in milliseconds (submit to result, includes queue wait). + pub(crate) round_trip_ms: u64, + /// When this proof job was submitted (for end-to-end latency tracking). + pub(crate) submitted_at: std::time::Instant, +} + +struct ProofClients { + append_client: ProofClient, + nullify_client: ProofClient, + address_append_client: ProofClient, +} + +impl ProofClients { + fn new(config: &ProverConfig) -> Self { + Self { + append_client: ProofClient::with_config( + config.append_url.clone(), + config.polling_interval, + config.max_wait_time, + config.api_key.clone(), + ), + nullify_client: ProofClient::with_config( + config.update_url.clone(), + config.polling_interval, + config.max_wait_time, + config.api_key.clone(), + ), + address_append_client: ProofClient::with_config( + config.address_append_url.clone(), + config.polling_interval, + config.max_wait_time, + config.api_key.clone(), + ), + } + } + + fn get_client(&self, input: &ProofInput) -> &ProofClient { + match input { + ProofInput::Append(_) => &self.append_client, + ProofInput::Nullify(_) => &self.nullify_client, + ProofInput::AddressAppend(_) => &self.address_append_client, + } + } +} + +pub fn spawn_proof_workers(config: &ProverConfig) -> async_channel::Sender { + let (job_tx, job_rx) = async_channel::bounded::(256); + let clients = Arc::new(ProofClients::new(config)); + tokio::spawn(async move { run_proof_pipeline(job_rx, clients).await }); + job_tx +} + +async fn run_proof_pipeline( + job_rx: Receiver, + clients: Arc, +) -> crate::Result<()> { + while let Ok(job) = job_rx.recv().await { + let clients = clients.clone(); + // Spawn immediately so we don't block receiving the next job + // while waiting for HTTP submission + 
tokio::spawn(async move { + submit_and_poll_proof(clients, job).await; + }); + } + + Ok(()) +} + +async fn submit_and_poll_proof(clients: Arc, job: ProofJob) { + let client = clients.get_client(&job.inputs); + // Use seq as batch_index for ordering in the prover queue + let inputs_json = job.inputs.to_json(&job.tree_id, job.seq); + let circuit_type = job.inputs.circuit_type(); + + let round_trip_start = std::time::Instant::now(); + + match client.submit_proof_async(inputs_json, circuit_type).await { + Ok(SubmitProofResult::Queued(job_id)) => { + debug!( + "Submitted proof job seq={} type={} job_id={}", + job.seq, circuit_type, job_id + ); + + poll_and_send_result( + clients, + job_id, + job.seq, + job.inputs, + job.tree_id, + job.result_tx, + round_trip_start, + ) + .await; + } + Ok(SubmitProofResult::Immediate(proof)) => { + let round_trip_ms = round_trip_start.elapsed().as_millis() as u64; + debug!( + "Got immediate proof for seq={} type={} round_trip={}ms", + job.seq, circuit_type, round_trip_ms + ); + + let result = + build_proof_result(job.seq, &job.inputs, proof, round_trip_ms, round_trip_start); + let _ = job.result_tx.send(result).await; + } + Err(e) => { + error!( + "Failed to submit proof job seq={} type={}: {}", + job.seq, circuit_type, e + ); + + let result = ProofJobResult { + seq: job.seq, + result: Err(format!("Submit failed: {}", e)), + old_root: [0u8; 32], + new_root: [0u8; 32], + proof_duration_ms: 0, + round_trip_ms: 0, + submitted_at: round_trip_start, + }; + let _ = job.result_tx.send(result).await; + } + } +} + +async fn poll_and_send_result( + clients: Arc, + job_id: String, + seq: u64, + inputs: ProofInput, + tree_id: String, + result_tx: mpsc::Sender, + round_trip_start: std::time::Instant, +) { + let client = clients.get_client(&inputs); + + // Poll; on job_not_found, resubmit once and poll the new job. 
+ let result = match client.poll_proof_completion(job_id.clone()).await { + Ok(proof) => { + let round_trip_ms = round_trip_start.elapsed().as_millis() as u64; + debug!( + "Proof completed for seq={} job_id={} round_trip={}ms proof={}ms", + seq, job_id, round_trip_ms, proof.proof_duration_ms + ); + build_proof_result(seq, &inputs, proof, round_trip_ms, round_trip_start) + } + Err(e) if is_job_not_found(&e) => { + warn!( + "Proof polling got job_not_found for seq={} job_id={}; retrying submit once", + seq, job_id + ); + tokio::time::sleep(Duration::from_millis(200)).await; + + let inputs_json = inputs.to_json(&tree_id, seq); + let circuit_type = inputs.circuit_type(); + match client.submit_proof_async(inputs_json, circuit_type).await { + Ok(SubmitProofResult::Queued(new_job_id)) => { + debug!( + "Resubmitted proof job seq={} type={} new_job_id={}", + seq, circuit_type, new_job_id + ); + match client.poll_proof_completion(new_job_id.clone()).await { + Ok(proof) => { + let round_trip_ms = round_trip_start.elapsed().as_millis() as u64; + debug!( + "Proof completed after retry for seq={} job_id={} round_trip={}ms", + seq, new_job_id, round_trip_ms + ); + build_proof_result(seq, &inputs, proof, round_trip_ms, round_trip_start) + } + Err(e2) => ProofJobResult { + seq, + result: Err(format!( + "Proof failed after retry job_id={}: {}", + new_job_id, e2 + )), + old_root: [0u8; 32], + new_root: [0u8; 32], + proof_duration_ms: 0, + round_trip_ms: 0, + submitted_at: round_trip_start, + }, + } + } + Ok(SubmitProofResult::Immediate(proof)) => { + let round_trip_ms = round_trip_start.elapsed().as_millis() as u64; + debug!( + "Immediate proof after retry for seq={} type={} round_trip={}ms", + seq, circuit_type, round_trip_ms + ); + build_proof_result(seq, &inputs, proof, round_trip_ms, round_trip_start) + } + Err(e_submit) => ProofJobResult { + seq, + result: Err(format!("Proof retry submit failed: {}", e_submit)), + old_root: [0u8; 32], + new_root: [0u8; 32], + proof_duration_ms: 
0, + round_trip_ms: 0, + submitted_at: round_trip_start, + }, + } + } + Err(e) => { + warn!( + "Proof polling failed for seq={} job_id={}: {}", + seq, job_id, e + ); + ProofJobResult { + seq, + result: Err(format!("Proof failed: {}", e)), + old_root: [0u8; 32], + new_root: [0u8; 32], + proof_duration_ms: 0, + round_trip_ms: 0, + submitted_at: round_trip_start, + } + } + }; + + if result_tx.send(result).await.is_err() { + debug!("Result channel closed for job seq={}", seq); + } +} + +fn is_job_not_found(err: &ProverClientError) -> bool { + matches!( + err, + ProverClientError::ProverServerError(msg) if msg.contains("job_not_found") + ) +} + +fn build_proof_result( + seq: u64, + inputs: &ProofInput, + proof_with_timing: ProofResult, + round_trip_ms: u64, + submitted_at: std::time::Instant, +) -> ProofJobResult { + let new_root = match inputs.new_root_bytes() { + Ok(root) => root, + Err(e) => { + return ProofJobResult { + seq, + result: Err(format!("Failed to get new root: {}", e)), + old_root: [0u8; 32], + new_root: [0u8; 32], + proof_duration_ms: proof_with_timing.proof_duration_ms, + round_trip_ms, + submitted_at, + }; + } + }; + let old_root = match inputs.old_root_bytes() { + Ok(root) => root, + Err(e) => { + return ProofJobResult { + seq, + result: Err(format!("Failed to get old root: {}", e)), + old_root: [0u8; 32], + new_root: [0u8; 32], + proof_duration_ms: proof_with_timing.proof_duration_ms, + round_trip_ms, + submitted_at, + }; + } + }; + + let proof = proof_with_timing.proof; + let instruction = match inputs { + ProofInput::Append(_) => BatchInstruction::Append(vec![InstructionDataBatchAppendInputs { + new_root, + compressed_proof: proof.into(), + }]), + ProofInput::Nullify(_) => { + BatchInstruction::Nullify(vec![InstructionDataBatchNullifyInputs { + new_root, + compressed_proof: proof.into(), + }]) + } + ProofInput::AddressAppend(_) => BatchInstruction::AddressAppend(vec![ + light_batched_merkle_tree::merkle_tree::InstructionDataAddressAppendInputs { + 
new_root, + compressed_proof: proof.into(), + }, + ]), + }; + + ProofJobResult { + seq, + old_root, + new_root, + result: Ok(instruction), + proof_duration_ms: proof_with_timing.proof_duration_ms, + round_trip_ms, + submitted_at, + } +} diff --git a/forester/src/processor/v2/root_guard.rs b/forester/src/processor/v2/root_guard.rs new file mode 100644 index 0000000000..85dc1d564a --- /dev/null +++ b/forester/src/processor/v2/root_guard.rs @@ -0,0 +1,153 @@ +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RootReconcileDecision { + Proceed, + WaitForIndexer, + ResetToOnchainAndProceed([u8; 32]), + ResetToOnchainAndStop([u8; 32]), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AlignmentDecision { + /// No overlap; safe to process the batch. + Process, + /// Batch overlaps already-processed items; skip it. + SkipOverlap, + /// There's a gap between what's expected and where this batch starts. + Gap, + /// Local staging tree is stale relative to the indexer snapshot; invalidate state. + StaleTree, +} + +/// Decide how to reconcile roots after fetching an indexer snapshot root and an on-chain root. 
+/// +/// Inputs: +/// - `expected_root`: the processor's local expected root (may be zero/uninitialized) +/// - `indexer_root`: the indexer's snapshot root for the fetched queue data +/// - `onchain_root`: the authoritative on-chain root +pub fn reconcile_roots( + expected_root: [u8; 32], + indexer_root: [u8; 32], + onchain_root: [u8; 32], +) -> RootReconcileDecision { + if expected_root == [0u8; 32] || indexer_root == expected_root { + return RootReconcileDecision::Proceed; + } + + if onchain_root == expected_root { + return RootReconcileDecision::WaitForIndexer; + } + + if indexer_root == onchain_root { + return RootReconcileDecision::ResetToOnchainAndProceed(onchain_root); + } + + RootReconcileDecision::ResetToOnchainAndStop(onchain_root) +} + +/// Decide whether a particular batch should be processed given: +/// - where the indexer snapshot starts (`data_start_index`) +/// - where the staging tree currently is (`tree_next_index`) +/// - the batch start offset within the snapshot (`start`) +/// +/// The return value is intentionally coarse-grained so callers can decide whether to retry, +/// invalidate caches, or simply skip work. 
+pub fn reconcile_alignment( + tree_next_index: usize, + data_start_index: usize, + start: usize, +) -> AlignmentDecision { + if data_start_index > tree_next_index { + return AlignmentDecision::StaleTree; + } + + let absolute_index = data_start_index + start; + + if absolute_index < tree_next_index { + return AlignmentDecision::SkipOverlap; + } + if absolute_index > tree_next_index { + return AlignmentDecision::Gap; + } + + AlignmentDecision::Process +} + +#[cfg(test)] +mod tests { + use super::*; + + fn root(byte: u8) -> [u8; 32] { + [byte; 32] + } + + #[test] + fn proceeds_when_expected_is_zero() { + assert_eq!( + reconcile_roots(root(0), root(1), root(2)), + RootReconcileDecision::Proceed + ); + } + + #[test] + fn proceeds_when_expected_matches_indexer() { + assert_eq!( + reconcile_roots(root(9), root(9), root(8)), + RootReconcileDecision::Proceed + ); + } + + #[test] + fn waits_when_onchain_confirms_expected() { + assert_eq!( + reconcile_roots(root(7), root(6), root(7)), + RootReconcileDecision::WaitForIndexer + ); + } + + #[test] + fn resets_and_proceeds_when_indexer_matches_onchain() { + assert_eq!( + reconcile_roots(root(7), root(6), root(6)), + RootReconcileDecision::ResetToOnchainAndProceed(root(6)) + ); + } + + #[test] + fn resets_and_stops_on_three_way_divergence() { + assert_eq!( + reconcile_roots(root(7), root(6), root(5)), + RootReconcileDecision::ResetToOnchainAndStop(root(5)) + ); + } + + #[test] + fn alignment_stale_when_data_starts_after_tree() { + assert_eq!(reconcile_alignment(10, 11, 0), AlignmentDecision::StaleTree); + } + + #[test] + fn alignment_skips_full_overlap() { + assert_eq!( + reconcile_alignment(10, 0, 0), + AlignmentDecision::SkipOverlap + ); + } + + #[test] + fn alignment_skips_partial_overlap() { + assert_eq!( + reconcile_alignment(10, 0, 8), + AlignmentDecision::SkipOverlap + ); + } + + #[test] + fn alignment_processes_when_no_overlap() { + assert_eq!(reconcile_alignment(10, 0, 10), AlignmentDecision::Process); + } + + #[test] + 
fn alignment_reports_gap_when_batch_starts_after_expected() { + assert_eq!(reconcile_alignment(10, 0, 12), AlignmentDecision::Gap); + } +} diff --git a/forester/src/processor/v2/state/helpers.rs b/forester/src/processor/v2/state/helpers.rs deleted file mode 100644 index 49ee2a07a1..0000000000 --- a/forester/src/processor/v2/state/helpers.rs +++ /dev/null @@ -1,77 +0,0 @@ -use anyhow::anyhow; -use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; -use light_client::{ - indexer::{Indexer, QueueElementsV2Options}, - rpc::Rpc, -}; -use tracing::warn; - -use crate::processor::v2::BatchContext; - -/// Fetches zkp_batch_size from on-chain merkle tree account (called once at startup) -pub async fn fetch_zkp_batch_size(context: &BatchContext) -> crate::Result { - let rpc = context.rpc_pool.get_connection().await?; - let mut account = rpc - .get_account(context.merkle_tree) - .await? - .ok_or_else(|| anyhow!("Merkle tree account not found"))?; - - let tree = BatchedMerkleTreeAccount::state_from_bytes( - account.data.as_mut_slice(), - &context.merkle_tree.into(), - )?; - - let batch_index = tree.queue_batches.pending_batch_index; - let batch = tree - .queue_batches - .batches - .get(batch_index as usize) - .ok_or_else(|| anyhow!("Batch not found"))?; - - Ok(batch.zkp_batch_size) -} - -pub async fn fetch_batches( - context: &BatchContext, - output_start_index: Option, - input_start_index: Option, - fetch_len: u64, - zkp_batch_size: u64, -) -> crate::Result> { - let fetch_len_u16: u16 = match fetch_len.try_into() { - Ok(v) => v, - Err(_) => { - warn!( - "fetch_len {} exceeds u16::MAX, clamping to {}", - fetch_len, - u16::MAX - ); - u16::MAX - } - }; - let zkp_batch_size_u16: u16 = match zkp_batch_size.try_into() { - Ok(v) => v, - Err(_) => { - warn!( - "zkp_batch_size {} exceeds u16::MAX, clamping to {}", - zkp_batch_size, - u16::MAX - ); - u16::MAX - } - }; - - let mut rpc = context.rpc_pool.get_connection().await?; - let indexer = rpc.indexer_mut()?; - let 
options = QueueElementsV2Options::default() - .with_output_queue(output_start_index, Some(fetch_len_u16)) - .with_output_queue_batch_size(Some(zkp_batch_size_u16)) - .with_input_queue(input_start_index, Some(fetch_len_u16)) - .with_input_queue_batch_size(Some(zkp_batch_size_u16)); - - let res = indexer - .get_queue_elements(context.merkle_tree.to_bytes(), options, None) - .await?; - - Ok(res.value.state_queue) -} diff --git a/forester/src/processor/v2/state/mod.rs b/forester/src/processor/v2/state/mod.rs deleted file mode 100644 index 3ab990aa8a..0000000000 --- a/forester/src/processor/v2/state/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -mod helpers; -mod proof_worker; -mod supervisor; -mod tx_sender; - -pub use supervisor::{ProcessQueueUpdate, QueueWork, StateSupervisor, UpdateEligibility}; diff --git a/forester/src/processor/v2/state/proof_worker.rs b/forester/src/processor/v2/state/proof_worker.rs deleted file mode 100644 index 4ac4605528..0000000000 --- a/forester/src/processor/v2/state/proof_worker.rs +++ /dev/null @@ -1,132 +0,0 @@ -use async_channel::Receiver; -use light_batched_merkle_tree::merkle_tree::{ - InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, -}; -use light_prover_client::{ - proof_client::ProofClient, - proof_types::{ - batch_append::BatchAppendsCircuitInputs, batch_update::BatchUpdateCircuitInputs, - }, -}; -use tokio::sync::mpsc; -use tracing::{debug, info, trace, warn}; - -use crate::processor::v2::{state::tx_sender::BatchInstruction, ProverConfig}; - -#[derive(Debug)] -pub enum ProofInput { - Append(BatchAppendsCircuitInputs), - Nullify(BatchUpdateCircuitInputs), -} - -pub struct ProofJob { - pub(crate) seq: u64, - pub(crate) inputs: ProofInput, - pub(crate) result_tx: mpsc::Sender, -} - -#[derive(Debug)] -pub struct ProofResult { - pub(crate) seq: u64, - pub(crate) instruction: BatchInstruction, -} - -pub fn spawn_proof_workers( - num_workers: usize, - config: ProverConfig, -) -> async_channel::Sender { - // Enforce minimum 
of 1 worker to prevent zero-capacity channels and no workers - let num_workers = if num_workers == 0 { - warn!("spawn_proof_workers called with num_workers=0, using 1 instead"); - 1 - } else { - num_workers - }; - - let channel_capacity = num_workers * 2; - let (job_tx, job_rx) = async_channel::bounded::(channel_capacity); - - for worker_id in 0..num_workers { - let job_rx = job_rx.clone(); - let config = config.clone(); - tokio::spawn(async move { run_proof_worker(worker_id, job_rx, config).await }); - } - - info!("Spawned {} proof workers", num_workers); - job_tx -} - -async fn run_proof_worker( - worker_id: usize, - job_rx: Receiver, - config: ProverConfig, -) -> crate::Result<()> { - let append_client = ProofClient::with_config( - config.append_url, - config.polling_interval, - config.max_wait_time, - config.api_key.clone(), - ); - let nullify_client = ProofClient::with_config( - config.update_url, - config.polling_interval, - config.max_wait_time, - config.api_key, - ); - - trace!("ProofWorker {} started", worker_id); - - while let Ok(job) = job_rx.recv().await { - debug!("ProofWorker {} processing job seq={}", worker_id, job.seq); - - let result = match job.inputs { - ProofInput::Append(inputs) => { - match append_client.generate_batch_append_proof(inputs).await { - Ok((proof, new_root)) => ProofResult { - seq: job.seq, - instruction: BatchInstruction::Append(vec![ - InstructionDataBatchAppendInputs { - new_root, - compressed_proof: proof.into(), - }, - ]), - }, - Err(e) => { - warn!("ProofWorker {} append proof failed: {}", worker_id, e); - continue; - } - } - } - ProofInput::Nullify(inputs) => { - match nullify_client.generate_batch_update_proof(inputs).await { - Ok((proof, new_root)) => ProofResult { - seq: job.seq, - instruction: BatchInstruction::Nullify(vec![ - InstructionDataBatchNullifyInputs { - new_root, - compressed_proof: proof.into(), - }, - ]), - }, - Err(e) => { - warn!("ProofWorker {} nullify proof failed: {}", worker_id, e); - continue; - } - 
} - } - }; - - // Send result via the job's own channel - if it's closed, just continue to next job - if job.result_tx.send(result).await.is_err() { - debug!( - "ProofWorker {} result channel closed for job seq={}, continuing", - worker_id, job.seq - ); - } else { - debug!("ProofWorker {} completed job seq={}", worker_id, job.seq); - } - } - - trace!("ProofWorker {} shutting down", worker_id); - Ok(()) -} diff --git a/forester/src/processor/v2/state/supervisor.rs b/forester/src/processor/v2/state/supervisor.rs deleted file mode 100644 index a7d7d19eb3..0000000000 --- a/forester/src/processor/v2/state/supervisor.rs +++ /dev/null @@ -1,541 +0,0 @@ -use anyhow::anyhow; -use forester_utils::staging_tree::{BatchType, StagingTree}; -use kameo::{ - actor::{ActorRef, WeakActorRef}, - error::ActorStopReason, - message::Message, - Actor, -}; -use light_batched_merkle_tree::constants::DEFAULT_BATCH_STATE_TREE_HEIGHT; -use light_client::rpc::Rpc; -use light_compressed_account::QueueType; -use light_prover_client::proof_types::{ - batch_append::BatchAppendsCircuitInputs, batch_update::BatchUpdateCircuitInputs, -}; -use light_registry::protocol_config::state::EpochState; -use tokio::sync::mpsc; -use tracing::{debug, info, trace, warn}; - -use crate::processor::v2::{ - state::{ - helpers::{fetch_batches, fetch_zkp_batch_size}, - proof_worker::{spawn_proof_workers, ProofInput, ProofJob, ProofResult}, - tx_sender::TxSender, - }, - BatchContext, -}; - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum Phase { - Append, - Nullify, -} - -#[derive(Debug, Clone)] -pub struct QueueWork { - pub queue_type: QueueType, - pub queue_size: u64, -} - -#[derive(Debug, Clone)] -pub struct ProcessQueueUpdate { - pub queue_work: QueueWork, -} - -#[derive(Debug, Clone)] -pub struct UpdateEligibility { - pub end_slot: u64, -} - -struct WorkerPool { - job_tx: async_channel::Sender, -} - -pub struct StateSupervisor { - context: BatchContext, - staging_tree: Option, - current_root: [u8; 32], - 
next_index: u64, - zkp_batch_size: u64, - seq: u64, - worker_pool: Option, -} - -impl Actor for StateSupervisor { - type Args = BatchContext; - type Error = anyhow::Error; - - async fn on_start( - context: Self::Args, - _actor_ref: ActorRef, - ) -> Result { - info!( - "StateSupervisor actor starting for tree {}", - context.merkle_tree - ); - - // Fetch zkp_batch_size once from on-chain (this is static per tree) - let zkp_batch_size = fetch_zkp_batch_size(&context).await?; - info!( - "StateSupervisor fetched zkp_batch_size={} for tree {}", - zkp_batch_size, context.merkle_tree - ); - - Ok(Self { - context, - staging_tree: None, - current_root: [0u8; 32], - next_index: 0, - zkp_batch_size, - seq: 0, - worker_pool: None, - }) - } - - async fn on_stop( - &mut self, - _actor_ref: WeakActorRef, - _reason: ActorStopReason, - ) -> Result<(), Self::Error> { - info!( - "StateSupervisor actor stopping for tree {}", - self.context.merkle_tree - ); - Ok(()) - } -} - -impl Message for StateSupervisor { - type Reply = crate::Result; - - async fn handle( - &mut self, - msg: ProcessQueueUpdate, - _ctx: &mut kameo::message::Context, - ) -> Self::Reply { - self.process_queue_update(msg.queue_work).await - } -} - -impl Message for StateSupervisor { - type Reply = (); - - async fn handle( - &mut self, - msg: UpdateEligibility, - _ctx: &mut kameo::message::Context, - ) -> Self::Reply { - debug!( - "Updating eligibility end slot to {} for tree {}", - msg.end_slot, self.context.merkle_tree - ); - self.context - .forester_eligibility_end_slot - .store(msg.end_slot, std::sync::atomic::Ordering::Relaxed); - } -} - -impl StateSupervisor { - fn zkp_batch_size(&self) -> u64 { - self.zkp_batch_size - } - - /// Gets the leaves hashchain for a batch, returning an error if not found. 
- fn get_leaves_hashchain( - leaves_hash_chains: &[[u8; 32]], - batch_idx: usize, - ) -> crate::Result<[u8; 32]> { - leaves_hash_chains.get(batch_idx).copied().ok_or_else(|| { - anyhow!( - "Missing leaves_hash_chain for batch {} (available: {})", - batch_idx, - leaves_hash_chains.len() - ) - }) - } - - /// Computes the slice range for a batch given total length and start index. - fn batch_range(&self, total_len: usize, start: usize) -> std::ops::Range { - let end = (start + self.zkp_batch_size as usize).min(total_len); - start..end - } - - /// Finalizes a proof job by updating state and returning the job. - fn finish_job( - &mut self, - new_root: [u8; 32], - inputs: ProofInput, - result_tx: mpsc::Sender, - ) -> Option { - self.current_root = new_root; - let job_seq = self.seq; - self.seq += 1; - Some(ProofJob { - seq: job_seq, - inputs, - result_tx, - }) - } - - fn ensure_worker_pool(&mut self) { - if self.worker_pool.is_none() { - let num_workers = self.context.num_proof_workers.max(1); - let job_tx = spawn_proof_workers(num_workers, self.context.prover_config.clone()); - - info!( - "StateSupervisor spawned {} persistent proof workers for tree {}", - num_workers, self.context.merkle_tree - ); - - self.worker_pool = Some(WorkerPool { job_tx }); - } - } - - async fn process_queue_update(&mut self, queue_work: QueueWork) -> crate::Result { - debug!( - "StateSupervisor processing queue update for tree {} (hint: {} items)", - self.context.merkle_tree, queue_work.queue_size - ); - - // Check if we're still in the active phase before processing - let current_slot = self.context.slot_tracker.estimated_current_slot(); - let current_phase = self - .context - .epoch_phases - .get_current_epoch_state(current_slot); - - if current_phase != EpochState::Active { - debug!( - "Skipping queue update: not in active phase (current: {:?}, slot: {}, epoch: {})", - current_phase, current_slot, self.context.epoch - ); - return Ok(0); - } - - let zkp_batch_size = self.zkp_batch_size(); - 
if queue_work.queue_size < zkp_batch_size { - trace!( - "Queue size {} below zkp_batch_size {}, skipping", - queue_work.queue_size, - zkp_batch_size - ); - return Ok(0); - } - - let phase = match queue_work.queue_type { - QueueType::OutputStateV2 => Phase::Append, - QueueType::InputStateV2 => Phase::Nullify, - other => { - warn!("Unsupported queue type for state processing: {:?}", other); - return Ok(0); - } - }; - - let max_batches = (queue_work.queue_size / zkp_batch_size) as usize; - if max_batches == 0 { - return Ok(0); - } - - self.ensure_worker_pool(); - - let num_workers = self.context.num_proof_workers.max(1); - - let (proof_tx, proof_rx) = mpsc::channel(num_workers * 2); - - // Reset seq counter - TxSender always expects seq to start at 0 - self.seq = 0; - - let tx_sender_handle = TxSender::spawn( - self.context.clone(), - proof_rx, - self.zkp_batch_size(), - self.current_root, - ); - - let job_tx = self - .worker_pool - .as_ref() - .expect("worker pool should be initialized") - .job_tx - .clone(); - let jobs_sent = self - .enqueue_batches(phase, max_batches, job_tx, proof_tx) - .await?; - - let tx_processed = match tx_sender_handle.await { - Ok(res) => match res { - Ok(processed) => processed, - Err(e) => { - warn!("Tx sender error, resetting staging tree: {}", e); - self.reset_staging_tree(); - return Err(e); - } - }, - Err(e) => { - warn!("Tx sender join error, resetting staging tree: {}", e); - self.reset_staging_tree(); - return Err(anyhow!("Tx sender join error: {}", e)); - } - }; - - if tx_processed < jobs_sent * self.zkp_batch_size as usize { - debug!( - "Processed {} items but sent {} jobs (expected {}), some proofs may have failed", - tx_processed, - jobs_sent, - jobs_sent * self.zkp_batch_size as usize - ); - } - - Ok(tx_processed) - } - - fn reset_staging_tree(&mut self) { - info!( - "Resetting staging tree for tree {}", - self.context.merkle_tree - ); - self.staging_tree = None; - } - - fn build_staging_tree( - &mut self, - leaf_indices: 
&[u64], - leaves: &[[u8; 32]], - nodes: &[u64], - node_hashes: &[[u8; 32]], - initial_root: [u8; 32], - root_seq: u64, - ) -> crate::Result<()> { - self.staging_tree = Some(StagingTree::new( - leaf_indices, - leaves, - nodes, - node_hashes, - initial_root, - root_seq, - )?); - debug!("Built staging tree from indexer (seq={})", root_seq); - Ok(()) - } - - async fn enqueue_batches( - &mut self, - phase: Phase, - max_batches: usize, - job_tx: async_channel::Sender, - result_tx: mpsc::Sender, - ) -> crate::Result { - let zkp_batch_size = self.zkp_batch_size() as usize; - let total_needed = max_batches.saturating_mul(zkp_batch_size); - let fetch_len = total_needed as u64; - - let state_queue = - fetch_batches(&self.context, None, None, fetch_len, self.zkp_batch_size()).await?; - - let Some(state_queue) = state_queue else { - return Ok(0); - }; - - let mut jobs_sent = 0usize; - - match phase { - Phase::Append => { - let Some(output_batch) = state_queue.output_queue.as_ref() else { - return Ok(0); - }; - if output_batch.leaf_indices.is_empty() { - return Ok(0); - } - - self.current_root = state_queue.initial_root; - self.next_index = output_batch.next_index; - info!( - "Synced from indexer: root {:?}[..4], next_index {}", - &self.current_root[..4], - self.next_index - ); - - self.build_staging_tree( - &output_batch.leaf_indices, - &output_batch.old_leaves, - &state_queue.nodes, - &state_queue.node_hashes, - state_queue.initial_root, - state_queue.root_seq, - )?; - - let available = output_batch.leaf_indices.len(); - let num_slices = (available / zkp_batch_size).min(max_batches); - - for batch_idx in 0..num_slices { - let start = batch_idx * zkp_batch_size; - if let Some(job) = self - .build_append_job(batch_idx, &state_queue, start, result_tx.clone()) - .await? 
- { - job_tx.send(job).await?; - jobs_sent += 1; - } else { - break; - } - } - } - Phase::Nullify => { - let Some(input_batch) = state_queue.input_queue.as_ref() else { - return Ok(0); - }; - if input_batch.leaf_indices.is_empty() { - return Ok(0); - } - - self.current_root = state_queue.initial_root; - info!( - "Synced from indexer: root {:?}[..4]", - &self.current_root[..4] - ); - - self.build_staging_tree( - &input_batch.leaf_indices, - &input_batch.current_leaves, - &state_queue.nodes, - &state_queue.node_hashes, - state_queue.initial_root, - state_queue.root_seq, - )?; - - let available = input_batch.leaf_indices.len(); - let num_slices = (available / zkp_batch_size).min(max_batches); - - for batch_idx in 0..num_slices { - let start = batch_idx * zkp_batch_size; - if let Some(job) = self - .build_nullify_job(batch_idx, &state_queue, start, result_tx.clone()) - .await? - { - job_tx.send(job).await?; - jobs_sent += 1; - } else { - break; - } - } - } - } - - drop(result_tx); - - info!("Enqueued {} jobs for proof generation", jobs_sent); - Ok(jobs_sent) - } - - async fn build_append_job( - &mut self, - batch_idx: usize, - state_queue: &light_client::indexer::StateQueueData, - start: usize, - result_tx: mpsc::Sender, - ) -> crate::Result> { - let batch = state_queue - .output_queue - .as_ref() - .ok_or_else(|| anyhow!("Output queue not present in state queue"))?; - - let range = self.batch_range(batch.account_hashes.len(), start); - let leaves = batch.account_hashes[range.clone()].to_vec(); - let leaf_indices = batch.leaf_indices[range].to_vec(); - - let hashchain_idx = start / self.zkp_batch_size as usize; - let batch_seq = state_queue.root_seq + (batch_idx as u64) + 1; - - let staging = self.staging_tree.as_mut().ok_or_else(|| { - anyhow!( - "Staging tree not initialized for append job (batch_idx={})", - batch_idx - ) - })?; - let result = staging.process_batch_updates( - &leaf_indices, - &leaves, - BatchType::Append, - batch_idx, - batch_seq, - )?; - let 
new_root = result.new_root; - - let leaves_hashchain = - Self::get_leaves_hashchain(&batch.leaves_hash_chains, hashchain_idx)?; - let start_index = leaf_indices.first().copied().unwrap_or(0) as u32; - - let circuit_inputs = - BatchAppendsCircuitInputs::new::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>( - result.into(), - start_index, - leaves.clone(), - leaves_hashchain, - self.zkp_batch_size as u32, - ) - .map_err(|e| anyhow!("Failed to build append inputs: {}", e))?; - - self.next_index = self.next_index.saturating_add(self.zkp_batch_size); - Ok(self.finish_job(new_root, ProofInput::Append(circuit_inputs), result_tx)) - } - - async fn build_nullify_job( - &mut self, - batch_idx: usize, - state_queue: &light_client::indexer::StateQueueData, - start: usize, - result_tx: mpsc::Sender, - ) -> crate::Result> { - let batch = state_queue - .input_queue - .as_ref() - .ok_or_else(|| anyhow!("Input queue not present in state queue"))?; - - let range = self.batch_range(batch.account_hashes.len(), start); - let account_hashes = batch.account_hashes[range.clone()].to_vec(); - let tx_hashes = batch.tx_hashes[range.clone()].to_vec(); - let nullifiers = batch.nullifiers[range.clone()].to_vec(); - let leaf_indices = batch.leaf_indices[range].to_vec(); - let hashchain_idx = start / self.zkp_batch_size as usize; - let batch_seq = state_queue.root_seq + (batch_idx as u64) + 1; - - let staging = self.staging_tree.as_mut().ok_or_else(|| { - anyhow!( - "Staging tree not initialized for nullify job (batch_idx={})", - batch_idx - ) - })?; - let result = staging.process_batch_updates( - &leaf_indices, - &nullifiers, - BatchType::Nullify, - batch_idx, - batch_seq, - )?; - info!( - "nullify batch {} root {:?}[..4] => {:?}[..4]", - batch_idx, - &result.old_root[..4], - &result.new_root[..4] - ); - - let new_root = result.new_root; - let leaves_hashchain = - Self::get_leaves_hashchain(&batch.leaves_hash_chains, hashchain_idx)?; - let path_indices: Vec = leaf_indices.iter().map(|idx| *idx as 
u32).collect(); - - let circuit_inputs = - BatchUpdateCircuitInputs::new::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>( - result.into(), - tx_hashes, - account_hashes, - leaves_hashchain, - path_indices, - self.zkp_batch_size as u32, - ) - .map_err(|e| anyhow!("Failed to build nullify inputs: {}", e))?; - - Ok(self.finish_job(new_root, ProofInput::Nullify(circuit_inputs), result_tx)) - } -} diff --git a/forester/src/processor/v2/state/tx_sender.rs b/forester/src/processor/v2/state/tx_sender.rs deleted file mode 100644 index d6df3127f0..0000000000 --- a/forester/src/processor/v2/state/tx_sender.rs +++ /dev/null @@ -1,124 +0,0 @@ -use std::collections::BTreeMap; - -use borsh::BorshSerialize; -use light_batched_merkle_tree::merkle_tree::{ - InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, -}; -use light_client::rpc::Rpc; -use light_registry::account_compression_cpi::sdk::{ - create_batch_append_instruction, create_batch_nullify_instruction, -}; -use solana_sdk::signature::Signer; -use tokio::{sync::mpsc, task::JoinHandle}; -use tracing::{info, warn}; - -use crate::{ - errors::ForesterError, - processor::v2::{ - common::send_transaction_batch, state::proof_worker::ProofResult, BatchContext, - }, -}; - -#[derive(Debug)] -pub enum BatchInstruction { - Append(Vec), - Nullify(Vec), -} - -pub struct TxSender { - context: BatchContext, - expected_seq: u64, - buffer: BTreeMap, - zkp_batch_size: u64, - last_seen_root: [u8; 32], -} - -impl TxSender { - pub(crate) fn spawn( - context: BatchContext, - proof_rx: mpsc::Receiver, - zkp_batch_size: u64, - last_seen_root: [u8; 32], - ) -> JoinHandle> { - let sender = Self { - context, - expected_seq: 0, - buffer: BTreeMap::new(), - zkp_batch_size, - last_seen_root, - }; - - tokio::spawn(async move { sender.run(proof_rx).await }) - } - - async fn run(mut self, mut proof_rx: mpsc::Receiver) -> crate::Result { - let mut processed = 0usize; - - while let Some(result) = proof_rx.recv().await { - 
self.buffer.insert(result.seq, result.instruction); - - while let Some(instr) = self.buffer.remove(&self.expected_seq) { - let (instructions, expected_root) = match &instr { - BatchInstruction::Append(proofs) => { - let ix = proofs - .iter() - .map(|data| { - Ok(create_batch_append_instruction( - self.context.authority.pubkey(), - self.context.derivation, - self.context.merkle_tree, - self.context.output_queue, - self.context.epoch, - data.try_to_vec()?, - )) - }) - .collect::>>()?; - (ix, proofs.last().map(|p| p.new_root)) - } - BatchInstruction::Nullify(proofs) => { - let ix = proofs - .iter() - .map(|data| { - Ok(create_batch_nullify_instruction( - self.context.authority.pubkey(), - self.context.derivation, - self.context.merkle_tree, - self.context.epoch, - data.try_to_vec()?, - )) - }) - .collect::>>()?; - (ix, proofs.last().map(|p| p.new_root)) - } - }; - - match send_transaction_batch(&self.context, instructions).await { - Ok(sig) => { - if let Some(root) = expected_root { - self.last_seen_root = root; - } - processed += self.zkp_batch_size as usize; - self.expected_seq += 1; - info!( - "tx sent {} root {:?} seq {} epoch {}", - sig, self.last_seen_root, self.expected_seq, self.context.epoch - ); - } - Err(e) => { - warn!("tx error {} epoch {}", e, self.context.epoch); - return if let Some(ForesterError::NotInActivePhase) = - e.downcast_ref::() - { - warn!("Active phase ended while sending tx, stopping sender loop"); - Ok(processed) - } else { - Err(e) - }; - } - } - } - } - - Ok(processed) - } -} diff --git a/forester/src/processor/v2/strategy/address.rs b/forester/src/processor/v2/strategy/address.rs new file mode 100644 index 0000000000..06e94d5500 --- /dev/null +++ b/forester/src/processor/v2/strategy/address.rs @@ -0,0 +1,382 @@ +use std::sync::Arc; + +use anyhow::anyhow; +use async_trait::async_trait; +use forester_utils::{ + address_staging_tree::{AddressStagingTree, AddressStagingTreeError}, + error::ForesterUtilsError, +}; +use 
light_batched_merkle_tree::constants::DEFAULT_BATCH_ADDRESS_TREE_HEIGHT; +use light_client::rpc::Rpc; +use light_compressed_account::QueueType; +use light_prover_client::errors::ProverClientError; +use tracing::{debug, info, instrument}; + +use crate::processor::v2::{ + batch_job_builder::BatchJobBuilder, + common::get_leaves_hashchain, + errors::V2Error, + helpers::{ + fetch_address_zkp_batch_size, fetch_onchain_address_root, fetch_streaming_address_batches, + lock_recover, StreamingAddressQueue, + }, + proof_worker::ProofInput, + root_guard::{reconcile_alignment, AlignmentDecision}, + strategy::{CircuitType, QueueData, TreeStrategy}, + BatchContext, +}; + +#[derive(Debug, Clone)] +pub struct AddressTreeStrategy; + +pub struct AddressQueueData { + pub staging_tree: AddressStagingTree, + pub streaming_queue: Arc, + pub data_start_index: u64, + pub zkp_batch_size: usize, +} + +impl AddressQueueData { + pub fn check_alignment(&self) -> Result { + let tree_next = self.staging_tree.next_index() as u64; + let data_start = self.data_start_index; + + if data_start > tree_next { + // Tree is stale - indexer has more elements than we know about + Err(AddressAlignmentError::TreeStale { + tree_next_index: tree_next, + data_start_index: data_start, + }) + } else if data_start == tree_next { + // Perfect alignment + Ok(0) + } else { + // Overlap - we've already processed some elements + let overlap = (tree_next - data_start) as usize; + Ok(overlap) + } + } + + /// Get the batch index to start processing from, accounting for overlap. + /// Returns None if tree is stale. 
+ pub fn first_processable_batch(&self) -> Option { + match self.check_alignment() { + Ok(overlap) => { + let batch_idx = overlap / self.zkp_batch_size; + Some(batch_idx) + } + Err(_) => None, + } + } +} + +#[derive(Debug, Clone)] +pub enum AddressAlignmentError { + TreeStale { + tree_next_index: u64, + data_start_index: u64, + }, +} + +impl std::fmt::Display for AddressAlignmentError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AddressAlignmentError::TreeStale { + tree_next_index, + data_start_index, + } => write!( + f, + "Address staging tree is stale: tree_next_index={}, data_start_index={}", + tree_next_index, data_start_index + ), + } + } +} + +impl std::error::Error for AddressAlignmentError {} + +impl std::fmt::Debug for AddressQueueData { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AddressQueueData") + .field("staging_tree", &self.staging_tree) + .field("data_start_index", &self.data_start_index) + .field( + "available_batches", + &self.streaming_queue.available_batches(), + ) + .field("alignment", &self.check_alignment()) + .finish() + } +} + +#[async_trait] +impl TreeStrategy for AddressTreeStrategy { + type StagingTree = AddressQueueData; + + fn name(&self) -> &'static str { + "Address" + } + + fn circuit_type(&self, _queue_data: &Self::StagingTree) -> CircuitType { + CircuitType::AddressAppend + } + + fn queue_type() -> QueueType { + QueueType::AddressV2 + } + + async fn fetch_zkp_batch_size(&self, context: &BatchContext) -> crate::Result { + fetch_address_zkp_batch_size(context).await + } + + async fn fetch_onchain_root(&self, context: &BatchContext) -> crate::Result<[u8; 32]> { + fetch_onchain_address_root(context).await + } + + #[instrument(level = "debug", skip(self, context), fields(tree = %context.merkle_tree))] + async fn fetch_queue_data( + &self, + context: &BatchContext, + max_batches: usize, + zkp_batch_size: u64, + ) -> crate::Result>> { + let 
zkp_batch_size_usize = zkp_batch_size as usize; + let total_needed = max_batches.saturating_mul(zkp_batch_size_usize); + let fetch_len = total_needed as u64; + + let streaming_queue = + match fetch_streaming_address_batches(context, fetch_len, zkp_batch_size).await? { + Some(sq) => Arc::new(sq), + None => { + debug!("No address queue data available"); + return Ok(None); + } + }; + + let subtrees = streaming_queue.subtrees(); + if subtrees.is_empty() { + return Err(anyhow!("Address queue missing subtrees data")); + } + + let initial_batches = streaming_queue.available_batches(); + if initial_batches == 0 { + debug!( + zkp_batch_size = zkp_batch_size_usize, + "Not enough addresses for a complete batch" + ); + return Ok(None); + } + + let initial_root = streaming_queue.initial_root(); + let start_index = streaming_queue.start_index(); + + let subtrees_arr: [[u8; 32]; DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as usize] = + subtrees.try_into().map_err(|v: Vec<[u8; 32]>| { + anyhow!( + "Subtrees length mismatch: expected {}, got {}", + DEFAULT_BATCH_ADDRESS_TREE_HEIGHT, + v.len() + ) + })?; + + let staging_tree = tokio::task::spawn_blocking(move || { + let start = std::time::Instant::now(); + let tree = AddressStagingTree::new(subtrees_arr, initial_root, start_index as usize); + info!( + "AddressStagingTree init took {:?}, start_index={}", + start.elapsed(), + start_index + ); + tree + }) + .await + .map_err(|e| anyhow!("spawn_blocking join error: {}", e))??; + + let num_batches = initial_batches.min(max_batches); + + info!( + "Address queue ready: {} batches available, processing {} (streaming in background), start_index={}", + initial_batches, num_batches, start_index + ); + + Ok(Some(QueueData { + staging_tree: AddressQueueData { + staging_tree, + streaming_queue, + data_start_index: start_index, + zkp_batch_size: zkp_batch_size as usize, + }, + initial_root, + num_batches, + })) + } + + fn available_batches(&self, queue_data: &Self::StagingTree, _zkp_batch_size: u64) -> 
usize { + queue_data.available_batches(_zkp_batch_size) + } +} + +impl BatchJobBuilder for AddressQueueData { + fn build_proof_job( + &mut self, + batch_idx: usize, + zkp_batch_size: u64, + epoch: u64, + tree: &str, + ) -> crate::Result> { + let zkp_batch_size_usize = zkp_batch_size as usize; + let start = batch_idx * zkp_batch_size_usize; + + let tree_next_index = self.staging_tree.next_index(); + let data_start = self.data_start_index as usize; + + match reconcile_alignment(tree_next_index, data_start, start) { + AlignmentDecision::StaleTree => { + return Err(V2Error::StaleTree { + tree_id: tree.to_string(), + details: format!( + "address staging tree is stale: tree_next_index={}, data_start_index={}", + tree_next_index, data_start + ), + } + .into()); + } + AlignmentDecision::SkipOverlap => { + let absolute_index = data_start + start; + tracing::debug!( + "Skipping address batch (overlap): absolute_index={}, tree_next_index={}, batch_size={}", + absolute_index, + tree_next_index, + zkp_batch_size_usize + ); + return Ok(None); + } + AlignmentDecision::Gap => { + let absolute_index = data_start + start; + return Err(V2Error::StaleTree { + tree_id: tree.to_string(), + details: format!( + "address batch gap: absolute_index={} > tree_next_index={} (batch_size={})", + absolute_index, tree_next_index, zkp_batch_size_usize + ), + } + .into()); + } + AlignmentDecision::Process => {} + } + + let batch_end = start + zkp_batch_size_usize; + + let batch_data = self + .streaming_queue + .get_batch_data(start, batch_end) + .ok_or_else(|| { + anyhow!( + "Batch data not available: start={}, end={}, available={}", + start, + batch_end, + self.streaming_queue.available_batches() * zkp_batch_size_usize + ) + })?; + + let addresses = &batch_data.addresses; + let zkp_batch_size_actual = addresses.len(); + + if zkp_batch_size_actual == 0 { + return Err(anyhow!("Empty batch at start={}", start)); + } + + let low_element_values = &batch_data.low_element_values; + let 
low_element_next_values = &batch_data.low_element_next_values; + let low_element_indices = &batch_data.low_element_indices; + let low_element_next_indices = &batch_data.low_element_next_indices; + + let low_element_proofs: Vec> = { + let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); + (start..start + zkp_batch_size_actual) + .map(|i| data.reconstruct_proof(i, DEFAULT_BATCH_ADDRESS_TREE_HEIGHT as u8)) + .collect::, _>>()? + }; + + let hashchain_idx = start / zkp_batch_size_usize; + let leaves_hashchain = { + let data = lock_recover(self.streaming_queue.data.as_ref(), "streaming_queue.data"); + get_leaves_hashchain(&data.leaves_hash_chains, hashchain_idx)? + }; + + let tree_batch = tree_next_index / zkp_batch_size_usize; + let absolute_index = data_start + start; + + tracing::debug!( + "Address build_proof_job: start={}, absolute_index={}, hashchain_idx={}, batch_size={}, tree_next_index={}, tree_batch={}, streaming_complete={}", + start, + absolute_index, + hashchain_idx, + zkp_batch_size_actual, + tree_next_index, + tree_batch, + self.streaming_queue.is_complete() + ); + + let result = self.staging_tree.process_batch( + addresses, + low_element_values, + low_element_next_values, + low_element_indices, + low_element_next_indices, + &low_element_proofs, + leaves_hashchain, + zkp_batch_size_actual, + epoch, + tree, + ); + + let result = match result { + Ok(r) => r, + Err(err) => return Err(map_address_staging_error(tree, err)), + }; + + Ok(Some(( + ProofInput::AddressAppend(result.circuit_inputs), + result.new_root, + ))) + } + + fn available_batches(&self, _zkp_batch_size: u64) -> usize { + self.streaming_queue.available_batches() + } +} + +fn map_address_staging_error(tree: &str, err: ForesterUtilsError) -> anyhow::Error { + match err { + ForesterUtilsError::AddressStagingTree(AddressStagingTreeError::CircuitInputs { + source: + ProverClientError::HashchainMismatch { + computed, + expected, + batch_size, + next_index, + }, + .. 
+ }) => V2Error::HashchainMismatch { + tree_id: tree.to_string(), + details: format!( + "computed {:?}[..4] != expected {:?}[..4] (batch_size={}, next_index={})", + &computed[..4], + &expected[..4], + batch_size, + next_index + ), + } + .into(), + ForesterUtilsError::AddressStagingTree(AddressStagingTreeError::CircuitInputs { + source: ProverClientError::ProofPatchFailed(details), + .. + }) => V2Error::ProofPatchFailed { + tree_id: tree.to_string(), + details, + } + .into(), + other => anyhow::anyhow!("{}", other), + } +} diff --git a/forester/src/processor/v2/strategy/mod.rs b/forester/src/processor/v2/strategy/mod.rs new file mode 100644 index 0000000000..fdac39b2bf --- /dev/null +++ b/forester/src/processor/v2/strategy/mod.rs @@ -0,0 +1,84 @@ +use async_trait::async_trait; +use light_client::rpc::Rpc; + +use crate::processor::v2::{ + batch_job_builder::BatchJobBuilder, proof_worker::ProofInput, BatchContext, +}; + +mod address; +mod state; + +pub use address::AddressTreeStrategy; +pub use state::StateTreeStrategy; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CircuitType { + Append, + Nullify, + AddressAppend, +} + +#[derive(Debug)] +pub struct QueueData { + pub staging_tree: T, + pub initial_root: [u8; 32], + pub num_batches: usize, +} + +use light_compressed_account::QueueType; + +#[async_trait] +pub trait TreeStrategy: Send + Sync + Clone + std::fmt::Debug + 'static { + type StagingTree: Send + 'static; + + fn name(&self) -> &'static str; + fn circuit_type(&self, queue_data: &Self::StagingTree) -> CircuitType; + + /// Returns the default queue type for this strategy + fn queue_type() -> QueueType; + + fn circuit_type_for_batch( + &self, + queue_data: &Self::StagingTree, + batch_idx: usize, + ) -> CircuitType { + let _ = batch_idx; + self.circuit_type(queue_data) + } + + async fn fetch_zkp_batch_size(&self, context: &BatchContext) -> crate::Result; + + async fn fetch_onchain_root(&self, context: &BatchContext) -> crate::Result<[u8; 32]>; + + async 
fn fetch_queue_data( + &self, + context: &BatchContext, + max_batches: usize, + zkp_batch_size: u64, + ) -> crate::Result>>; + + /// Build proof job for a batch. Returns: + /// - `Ok(Some((input, root)))` - batch processed, proof job created + /// - `Ok(None)` - batch should be skipped (e.g., overlap with already-processed data) + /// - `Err(...)` - fatal error, stop processing + fn build_proof_job( + &self, + queue_data: &mut Self::StagingTree, + batch_idx: usize, + zkp_batch_size: u64, + epoch: u64, + tree: &str, + ) -> crate::Result> + where + Self::StagingTree: BatchJobBuilder, + { + BatchJobBuilder::build_proof_job(queue_data, batch_idx, zkp_batch_size, epoch, tree) + } + + /// Returns the number of batches currently available in the staging tree. + /// For streaming implementations, this may increase as more data is fetched. + /// Default implementation returns usize::MAX (unlimited). + fn available_batches(&self, _queue_data: &Self::StagingTree, _zkp_batch_size: u64) -> usize { + usize::MAX + } +} diff --git a/forester/src/processor/v2/strategy/state.rs b/forester/src/processor/v2/strategy/state.rs new file mode 100644 index 0000000000..d517c51757 --- /dev/null +++ b/forester/src/processor/v2/strategy/state.rs @@ -0,0 +1,633 @@ +use anyhow::anyhow; +use async_trait::async_trait; +use forester_utils::staging_tree::{BatchType, StagingTree}; +use light_batched_merkle_tree::constants::DEFAULT_BATCH_STATE_TREE_HEIGHT; +use light_client::rpc::Rpc; +use light_prover_client::proof_types::{ + batch_append::BatchAppendsCircuitInputs, batch_update::BatchUpdateCircuitInputs, +}; +use tracing::{debug, instrument}; + +use crate::processor::v2::{ + batch_job_builder::BatchJobBuilder, + common::{batch_range, get_leaves_hashchain}, + helpers::{fetch_onchain_state_root, fetch_paginated_batches, fetch_zkp_batch_size}, + proof_worker::ProofInput, + root_guard::{reconcile_alignment, AlignmentDecision}, + strategy::{CircuitType, QueueData, TreeStrategy}, + BatchContext, +}; + 
+#[derive(Debug, Clone)] +pub struct StateTreeStrategy; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum StatePhase { + Append, + Nullify, +} + +#[derive(Debug, Clone, Copy)] +pub enum BatchOp { + Append(usize), + Nullify(usize), +} + +#[derive(Debug)] +pub struct StateQueueData { + pub staging_tree: StagingTree, + pub state_queue: light_client::indexer::StateQueueData, + pub phase: StatePhase, + pub next_index: Option, + pub append_batches_before_nullify: usize, + pub interleaved_ops: Option>, + /// First queue index for output queue data (where this batch starts) + pub output_first_queue_index: u64, + /// First queue index for input queue data (where this batch starts) + pub input_first_queue_index: u64, + /// Number of output queue elements processed (for alignment tracking) + pub output_processed: usize, + /// Number of input queue elements processed (for alignment tracking) + pub input_processed: usize, + /// ZKP batch size for alignment calculations + pub zkp_batch_size: usize, +} + +impl StateQueueData { + /// Get number of remaining output batches + pub fn remaining_output_batches(&self) -> usize { + let total_output = self + .state_queue + .output_queue + .as_ref() + .map(|oq| oq.leaf_indices.len()) + .unwrap_or(0); + let remaining = total_output.saturating_sub(self.output_processed); + remaining / self.zkp_batch_size + } + + /// Get number of remaining input batches + pub fn remaining_input_batches(&self) -> usize { + let total_input = self + .state_queue + .input_queue + .as_ref() + .map(|iq| iq.leaf_indices.len()) + .unwrap_or(0); + let remaining = total_input.saturating_sub(self.input_processed); + remaining / self.zkp_batch_size + } + + /// Get expected next output queue index (for alignment validation when re-fetching) + pub fn expected_output_queue_index(&self) -> u64 { + self.output_first_queue_index + self.output_processed as u64 + } + + /// Get expected next input queue index (for alignment validation when re-fetching) + pub fn 
expected_input_queue_index(&self) -> u64 { + self.input_first_queue_index + self.input_processed as u64 + } +} + +use light_compressed_account::QueueType; + +#[async_trait] +impl TreeStrategy for StateTreeStrategy { + type StagingTree = StateQueueData; + + fn name(&self) -> &'static str { + "State" + } + + fn circuit_type(&self, queue_data: &Self::StagingTree) -> CircuitType { + match queue_data.phase { + StatePhase::Append => CircuitType::Append, + StatePhase::Nullify => CircuitType::Nullify, + } + } + + fn queue_type() -> QueueType { + QueueType::OutputStateV2 + } + + fn circuit_type_for_batch( + &self, + queue_data: &Self::StagingTree, + batch_idx: usize, + ) -> CircuitType { + if let Some(ref ops) = queue_data.interleaved_ops { + if let Some(op) = ops.get(batch_idx) { + return match op { + BatchOp::Append(_) => CircuitType::Append, + BatchOp::Nullify(_) => CircuitType::Nullify, + }; + } + } + + let is_append_phase = batch_idx < queue_data.append_batches_before_nullify + || (queue_data.append_batches_before_nullify == 0 + && queue_data.phase == StatePhase::Append); + + if is_append_phase { + CircuitType::Append + } else { + CircuitType::Nullify + } + } + + async fn fetch_zkp_batch_size(&self, context: &BatchContext) -> crate::Result { + fetch_zkp_batch_size(context).await + } + + async fn fetch_onchain_root(&self, context: &BatchContext) -> crate::Result<[u8; 32]> { + fetch_onchain_state_root(context).await + } + + #[instrument(level = "debug", skip(self, context), fields(tree = %context.merkle_tree))] + async fn fetch_queue_data( + &self, + context: &BatchContext, + max_batches: usize, + zkp_batch_size: u64, + ) -> crate::Result>> { + let zkp_batch_size_usize = zkp_batch_size as usize; + let total_needed = max_batches.saturating_mul(zkp_batch_size_usize); + let fetch_len = total_needed as u64; + + let state_queue = match fetch_paginated_batches(context, fetch_len, zkp_batch_size).await? 
{ + Some(sq) => sq, + None => return Ok(None), + }; + + let initial_root = state_queue.initial_root; + let root_seq = state_queue.root_seq; + let nodes = &state_queue.nodes; + let node_hashes = &state_queue.node_hashes; + + let append_items = state_queue + .output_queue + .as_ref() + .map(|oq| oq.leaf_indices.len()) + .unwrap_or(0); + let nullify_items = state_queue + .input_queue + .as_ref() + .map(|iq| iq.leaf_indices.len()) + .unwrap_or(0); + + debug!( + append_items, + nullify_items, + output_queue = state_queue.output_queue.is_some(), + input_queue = state_queue.input_queue.is_some(), + "Queue data fetched" + ); + + let append_batches = append_items / zkp_batch_size_usize; + let nullify_batches = nullify_items / zkp_batch_size_usize; + + let (append_batches_before_nullify, total_batches, effective_phase) = + if append_batches > 0 && nullify_batches > 0 { + let total = (append_batches + nullify_batches).min(max_batches); + let half_batches = max_batches / 2; + let appends_to_process = append_batches.min(half_batches).max(1); + let nullifies_to_process = + nullify_batches.min(total.saturating_sub(appends_to_process)); + let actual_total = appends_to_process + nullifies_to_process; + debug!( + "Processing {} APPEND batches then {} NULLIFY batches (total: {})", + appends_to_process, nullifies_to_process, actual_total + ); + (appends_to_process, actual_total, StatePhase::Append) + } else if append_batches > 0 { + (0, append_batches.min(max_batches), StatePhase::Append) + } else if nullify_batches > 0 { + (0, nullify_batches.min(max_batches), StatePhase::Nullify) + } else { + return Ok(None); + }; + + let (leaf_indices, leaves, next_index) = + if append_batches_before_nullify > 0 { + let output_batch = state_queue.output_queue.as_ref().ok_or_else(|| { + anyhow!("Expected output_queue batch when processing appends") + })?; + let input_batch = state_queue.input_queue.as_ref().ok_or_else(|| { + anyhow!("Expected input_queue batch when processing nullifies") + })?; + + 
let mut combined_indices = output_batch.leaf_indices.clone(); + let mut combined_leaves = output_batch.old_leaves.clone(); + + combined_indices.extend(input_batch.leaf_indices.iter().copied()); + combined_leaves.extend(input_batch.current_leaves.iter().copied()); + + ( + combined_indices, + combined_leaves, + Some(output_batch.next_index), + ) + } else { + match effective_phase { + StatePhase::Append => { + let batch = state_queue.output_queue.as_ref().ok_or_else(|| { + anyhow!("Expected output_queue batch when processing appends") + })?; + ( + batch.leaf_indices.clone(), + batch.old_leaves.clone(), + Some(batch.next_index), + ) + } + StatePhase::Nullify => { + let batch = state_queue.input_queue.as_ref().ok_or_else(|| { + anyhow!("Expected input_queue batch when processing nullifies") + })?; + ( + batch.leaf_indices.clone(), + batch.current_leaves.clone(), + None, + ) + } + } + }; + + let nodes = nodes.clone(); + let node_hashes = node_hashes.clone(); + let staging_tree = tokio::task::spawn_blocking(move || { + let start = std::time::Instant::now(); + let tree = StagingTree::new( + &leaf_indices, + &leaves, + &nodes, + &node_hashes, + initial_root, + root_seq, + DEFAULT_BATCH_STATE_TREE_HEIGHT as usize, + ); + debug!( + "StagingTree init took {:?}, leaves={}, nodes={}", + start.elapsed(), + leaf_indices.len(), + nodes.len() + ); + tree + }) + .await + .map_err(|e| anyhow!("spawn_blocking join error: {}", e))??; + + if total_batches == 0 { + return Ok(None); + } + + let interleaved_ops = if append_batches_before_nullify > 0 { + let output_batch = state_queue.output_queue.as_ref().ok_or_else(|| { + anyhow!("Expected output_queue batch when computing interleaving ops") + })?; + let input_batch = state_queue.input_queue.as_ref().ok_or_else(|| { + anyhow!("Expected input_queue batch when computing interleaving ops") + })?; + let initial_next_index = output_batch.next_index; + + let nullifies_to_process = total_batches.saturating_sub(append_batches_before_nullify); + + 
tracing::info!( + "Interleave check: initial_next_index={}, nullify leaf_indices[0..min(10,len)]={:?}, batch_size={}, num_appends={}, num_nullifies={}", + initial_next_index, + &input_batch.leaf_indices[..input_batch.leaf_indices.len().min(10)], + zkp_batch_size, + append_batches_before_nullify, + nullifies_to_process + ); + + Some(compute_interleaved_ops( + append_batches_before_nullify, + nullifies_to_process, + initial_next_index, + zkp_batch_size, + &input_batch.leaf_indices, + )) + } else { + None + }; + + let interleaved_total = interleaved_ops + .as_ref() + .map(|ops| ops.len()) + .unwrap_or(total_batches); + if let Some(ops) = interleaved_ops.as_ref() { + tracing::info!( + "Interleaved ops: {} total ({} append, {} nullify)", + interleaved_total, + ops.iter() + .filter(|op| matches!(op, BatchOp::Append(_))) + .count(), + ops.iter() + .filter(|op| matches!(op, BatchOp::Nullify(_))) + .count(), + ); + } + + let output_first_queue_index = state_queue + .output_queue + .as_ref() + .map(|oq| oq.first_queue_index) + .unwrap_or(0); + let input_first_queue_index = state_queue + .input_queue + .as_ref() + .map(|iq| iq.first_queue_index) + .unwrap_or(0); + + tracing::info!( + "State queue ready: output_first_queue_index={}, input_first_queue_index={}, batches={}", + output_first_queue_index, input_first_queue_index, interleaved_total + ); + + Ok(Some(QueueData { + staging_tree: StateQueueData { + staging_tree, + state_queue, + phase: effective_phase, + next_index, + append_batches_before_nullify, + interleaved_ops, + output_first_queue_index, + input_first_queue_index, + output_processed: 0, + input_processed: 0, + zkp_batch_size: zkp_batch_size_usize, + }, + initial_root, + num_batches: interleaved_total, + })) + } +} + +impl StateQueueData { + fn build_append_job( + &mut self, + batch_idx: usize, + start: usize, + zkp_batch_size: u64, + epoch: u64, + tree: &str, + ) -> crate::Result> { + let zkp_batch_size_usize = zkp_batch_size as usize; + let expected_queue_index 
= self.output_first_queue_index as usize + self.output_processed; + let data_start_index = self.output_first_queue_index as usize; + + match reconcile_alignment(expected_queue_index, data_start_index, start) { + AlignmentDecision::SkipOverlap => { + let absolute_queue_index = data_start_index + start; + tracing::debug!( + "Skipping output queue batch (overlap): absolute_index={}, expected_start={}", + absolute_queue_index, + expected_queue_index + ); + return Ok(None); + } + AlignmentDecision::Gap | AlignmentDecision::StaleTree => { + let absolute_queue_index = data_start_index + start; + return Err(anyhow!( + "Output queue stale: expected start {}, got {}. Need to invalidate cache.", + expected_queue_index, + absolute_queue_index + )); + } + AlignmentDecision::Process => {} + } + + let batch = self + .state_queue + .output_queue + .as_ref() + .ok_or_else(|| anyhow!("Output queue not present"))?; + + let range = batch_range(zkp_batch_size, batch.account_hashes.len(), start); + let leaves = &batch.account_hashes[range.clone()]; + let leaf_indices = &batch.leaf_indices[range]; + + let hashchain_idx = start / zkp_batch_size_usize; + let batch_seq = self.state_queue.root_seq + (batch_idx as u64) + 1; + + let result = self.staging_tree.process_batch_updates( + leaf_indices, + leaves, + BatchType::Append, + batch_idx, + batch_seq, + epoch, + tree, + )?; + + self.output_processed += zkp_batch_size_usize; + + let new_root = result.new_root; + let leaves_hashchain = get_leaves_hashchain(&batch.leaves_hash_chains, hashchain_idx)?; + let start_index = leaf_indices.first().copied().unwrap_or(0) as u32; + + let circuit_inputs = + BatchAppendsCircuitInputs::new::<{ DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>( + result.into(), + start_index, + leaves, + leaves_hashchain, + zkp_batch_size as u32, + ) + .map_err(|e| anyhow!("Failed to build append inputs: {}", e))?; + + Ok(Some((ProofInput::Append(circuit_inputs), new_root))) + } + + fn build_nullify_job( + &mut self, + batch_idx: 
usize, + start: usize, + zkp_batch_size: u64, + epoch: u64, + tree: &str, + ) -> crate::Result> { + let zkp_batch_size_usize = zkp_batch_size as usize; + let expected_queue_index = self.input_first_queue_index as usize + self.input_processed; + let data_start_index = self.input_first_queue_index as usize; + + match reconcile_alignment(expected_queue_index, data_start_index, start) { + AlignmentDecision::SkipOverlap => { + let absolute_queue_index = data_start_index + start; + tracing::debug!( + "Skipping input queue batch (overlap): absolute_index={}, expected_start={}", + absolute_queue_index, + expected_queue_index + ); + return Ok(None); + } + AlignmentDecision::Gap | AlignmentDecision::StaleTree => { + let absolute_queue_index = data_start_index + start; + return Err(anyhow!( + "Input queue stale: expected start {}, got {}. Need to invalidate cache.", + expected_queue_index, + absolute_queue_index + )); + } + AlignmentDecision::Process => {} + } + + let batch = self + .state_queue + .input_queue + .as_ref() + .ok_or_else(|| anyhow!("Input queue not present"))?; + + let range = batch_range(zkp_batch_size, batch.account_hashes.len(), start); + let account_hashes = &batch.account_hashes[range.clone()]; + let tx_hashes = &batch.tx_hashes[range.clone()]; + let nullifiers = &batch.nullifiers[range.clone()]; + let leaf_indices = &batch.leaf_indices[range]; + + let hashchain_idx = start / zkp_batch_size_usize; + let batch_seq = self.state_queue.root_seq + (batch_idx as u64) + 1; + + let result = self.staging_tree.process_batch_updates( + leaf_indices, + nullifiers, + BatchType::Nullify, + batch_idx, + batch_seq, + epoch, + tree, + )?; + + self.input_processed += zkp_batch_size_usize; + + let new_root = result.new_root; + let leaves_hashchain = get_leaves_hashchain(&batch.leaves_hash_chains, hashchain_idx)?; + let path_indices: Vec = leaf_indices.iter().map(|idx| *idx as u32).collect(); + + let circuit_inputs = + BatchUpdateCircuitInputs::new::<{ 
DEFAULT_BATCH_STATE_TREE_HEIGHT as usize }>( + result.into(), + tx_hashes, + account_hashes, + leaves_hashchain, + &path_indices, + zkp_batch_size as u32, + ) + .map_err(|e| anyhow!("Failed to build nullify inputs: {}", e))?; + + Ok(Some((ProofInput::Nullify(circuit_inputs), new_root))) + } +} + +impl BatchJobBuilder for StateQueueData { + fn build_proof_job( + &mut self, + batch_idx: usize, + zkp_batch_size: u64, + epoch: u64, + tree: &str, + ) -> crate::Result> { + if let Some(ref ops) = self.interleaved_ops { + if let Some(op) = ops.get(batch_idx) { + return match op { + BatchOp::Append(append_idx) => { + let start = append_idx * zkp_batch_size as usize; + self.build_append_job(*append_idx, start, zkp_batch_size, epoch, tree) + } + BatchOp::Nullify(nullify_idx) => { + let start = nullify_idx * zkp_batch_size as usize; + self.build_nullify_job(*nullify_idx, start, zkp_batch_size, epoch, tree) + } + }; + } + } + + let is_append_phase = batch_idx < self.append_batches_before_nullify + || (self.append_batches_before_nullify == 0 && self.phase == StatePhase::Append); + + if is_append_phase { + let start = batch_idx * zkp_batch_size as usize; + self.build_append_job(batch_idx, start, zkp_batch_size, epoch, tree) + } else { + let nullify_batch_idx = batch_idx - self.append_batches_before_nullify; + let start = nullify_batch_idx * zkp_batch_size as usize; + self.build_nullify_job(nullify_batch_idx, start, zkp_batch_size, epoch, tree) + } + } +} + +fn compute_interleaved_ops( + num_appends: usize, + num_nullifies: usize, + initial_next_index: u64, + batch_size: u64, + nullify_leaf_indices: &[u64], +) -> Vec { + let batch_size_usize = batch_size as usize; + let mut ops = Vec::with_capacity(num_appends + num_nullifies); + + let mut appends_scheduled = 0usize; + let mut nullifies_scheduled = 0usize; + + let nullify_batch_max_indices: Vec = (0..num_nullifies) + .map(|batch_idx| { + let start = batch_idx * batch_size_usize; + let end = ((batch_idx + 1) * 
batch_size_usize).min(nullify_leaf_indices.len()); + nullify_leaf_indices[start..end] + .iter() + .copied() + .max() + .unwrap_or(0) + }) + .collect(); + + if !nullify_batch_max_indices.is_empty() { + tracing::info!( + "compute_interleaved_ops: nullify_batch_max_indices[0..min(5,len)]={:?}", + &nullify_batch_max_indices[..nullify_batch_max_indices.len().min(5)] + ); + } + + while appends_scheduled < num_appends || nullifies_scheduled < num_nullifies { + if appends_scheduled < num_appends { + ops.push(BatchOp::Append(appends_scheduled)); + appends_scheduled += 1; + } + + let boundary = initial_next_index + (appends_scheduled as u64 * batch_size); + let mut scheduled_this_round = 0; + while nullifies_scheduled < num_nullifies { + let max_leaf_idx = nullify_batch_max_indices[nullifies_scheduled]; + if max_leaf_idx < boundary { + ops.push(BatchOp::Nullify(nullifies_scheduled)); + nullifies_scheduled += 1; + scheduled_this_round += 1; + } else { + if nullifies_scheduled == 0 && appends_scheduled <= 2 { + tracing::info!( + "Nullify batch {} skipped: max_leaf_idx={} >= boundary={} (initial_next_index={}, appends_scheduled={})", + nullifies_scheduled, max_leaf_idx, boundary, initial_next_index, appends_scheduled + ); + } + break; + } + } + if scheduled_this_round > 0 && appends_scheduled <= 2 { + tracing::info!( + "After append {}: scheduled {} nullifies (boundary={})", + appends_scheduled - 1, + scheduled_this_round, + boundary + ); + } + + if appends_scheduled >= num_appends && nullifies_scheduled < num_nullifies { + while nullifies_scheduled < num_nullifies { + ops.push(BatchOp::Nullify(nullifies_scheduled)); + nullifies_scheduled += 1; + } + } + } + + ops +} diff --git a/forester/src/processor/v2/tx_sender.rs b/forester/src/processor/v2/tx_sender.rs new file mode 100644 index 0000000000..22c634cde5 --- /dev/null +++ b/forester/src/processor/v2/tx_sender.rs @@ -0,0 +1,552 @@ +use std::{ + sync::{atomic::Ordering, Arc}, + time::Duration, +}; + +use borsh::BorshSerialize; 
+ +const MAX_BUFFER_SIZE: usize = 1000; +const V2_IXS_PER_TX_WITH_LUT: usize = 5; +const V2_IXS_PER_TX_WITHOUT_LUT: usize = 4; +const FLUSH_MARGIN_SLOTS: u64 = 2; + +use light_batched_merkle_tree::merkle_tree::{ + InstructionDataBatchAppendInputs, InstructionDataBatchNullifyInputs, +}; +use light_client::rpc::Rpc; +use light_registry::account_compression_cpi::sdk::{ + create_batch_append_instruction, create_batch_nullify_instruction, + create_batch_update_address_tree_instruction, +}; +use solana_sdk::{instruction::Instruction, signature::Signer}; +use tokio::{sync::mpsc, task::JoinHandle}; +use tracing::{debug, info, warn}; + +use crate::{ + errors::ForesterError, + processor::v2::{ + common::send_transaction_batch, proof_cache::SharedProofCache, + proof_worker::ProofJobResult, BatchContext, + }, +}; + +#[derive(Debug, Clone, Default)] +pub struct ProofTimings { + pub append_proof_ms: u64, + pub nullify_proof_ms: u64, + pub address_append_proof_ms: u64, + pub append_round_trip_ms: u64, + pub nullify_round_trip_ms: u64, + pub address_append_round_trip_ms: u64, +} + +impl ProofTimings { + pub fn append_proof_duration(&self) -> Duration { + Duration::from_millis(self.append_proof_ms) + } + pub fn nullify_proof_duration(&self) -> Duration { + Duration::from_millis(self.nullify_proof_ms) + } + pub fn address_append_proof_duration(&self) -> Duration { + Duration::from_millis(self.address_append_proof_ms) + } + pub fn append_round_trip_duration(&self) -> Duration { + Duration::from_millis(self.append_round_trip_ms) + } + pub fn nullify_round_trip_duration(&self) -> Duration { + Duration::from_millis(self.nullify_round_trip_ms) + } + pub fn address_append_round_trip_duration(&self) -> Duration { + Duration::from_millis(self.address_append_round_trip_ms) + } +} + +/// Result of TxSender processing +#[derive(Debug, Clone, Default)] +pub struct TxSenderResult { + pub items_processed: usize, + pub proof_timings: ProofTimings, + /// Number of proofs saved to cache when epoch 
ended (for potential reuse) + pub proofs_saved_to_cache: usize, + /// Total time spent sending transactions + pub tx_sending_duration: Duration, +} + +#[derive(Debug, Clone)] +pub enum BatchInstruction { + Append(Vec), + Nullify(Vec), + AddressAppend(Vec), +} + +impl BatchInstruction { + /// Returns the number of ZKP batch instructions contained in this batch. + pub fn items_count(&self) -> usize { + match self { + BatchInstruction::Append(v) => v.len(), + BatchInstruction::Nullify(v) => v.len(), + BatchInstruction::AddressAppend(v) => v.len(), + } + } +} + +/// Entry in the ordered proof buffer: instruction + timing info +#[derive(Clone)] +struct BufferEntry { + instruction: BatchInstruction, + round_trip_ms: u64, + proof_ms: u64, + submitted_at: std::time::Instant, +} + +struct OrderedProofBuffer { + buffer: Vec>, + base_seq: u64, + len: usize, + head: usize, +} + +impl OrderedProofBuffer { + fn new(capacity: usize) -> Self { + Self { + buffer: (0..capacity).map(|_| None).collect(), + base_seq: 0, + len: 0, + head: 0, + } + } + + fn capacity(&self) -> usize { + self.buffer.len() + } + + fn len(&self) -> usize { + self.len + } + + fn insert( + &mut self, + seq: u64, + instruction: BatchInstruction, + round_trip_ms: u64, + proof_ms: u64, + submitted_at: std::time::Instant, + ) -> bool { + if seq < self.base_seq { + return false; + } + let offset = (seq - self.base_seq) as usize; + if offset >= self.buffer.len() { + return false; + } + let idx = (self.head + offset) % self.buffer.len(); + if self.buffer[idx].is_none() { + self.len += 1; + } + self.buffer[idx] = Some(BufferEntry { + instruction, + round_trip_ms, + proof_ms, + submitted_at, + }); + true + } + + fn pop_next(&mut self) -> Option { + let item = self.buffer[self.head].take(); + if item.is_some() { + self.len -= 1; + self.base_seq += 1; + self.head = (self.head + 1) % self.buffer.len(); + } + item + } + + fn expected_seq(&self) -> u64 { + self.base_seq + } +} + +pub struct TxSender { + context: 
BatchContext, + buffer: OrderedProofBuffer, + zkp_batch_size: u64, + last_seen_root: [u8; 32], + pending_batch: Vec<(BatchInstruction, u64)>, // (instruction, seq) + pending_batch_round_trip_ms: u64, + pending_batch_proof_ms: u64, + /// Earliest submission time in the pending batch (for end-to-end latency) + pending_batch_earliest_submit: Option, + proof_timings: ProofTimings, + /// Optional cache to save unused proofs when epoch ends (for reuse in next epoch) + proof_cache: Option>, + /// Maximum instructions per transaction + ixs_per_tx: usize, +} + +impl TxSender { + pub(crate) fn spawn( + context: BatchContext, + proof_rx: mpsc::Receiver, + zkp_batch_size: u64, + last_seen_root: [u8; 32], + proof_cache: Option>, + ) -> JoinHandle> { + let ixs_per_tx = if context.address_lookup_tables.is_empty() { + V2_IXS_PER_TX_WITHOUT_LUT + } else { + V2_IXS_PER_TX_WITH_LUT + }; + + let sender = Self { + context, + buffer: OrderedProofBuffer::new(MAX_BUFFER_SIZE), + zkp_batch_size, + last_seen_root, + pending_batch: Vec::with_capacity(ixs_per_tx), + pending_batch_round_trip_ms: 0, + pending_batch_proof_ms: 0, + pending_batch_earliest_submit: None, + proof_timings: ProofTimings::default(), + proof_cache, + ixs_per_tx, + }; + + tokio::spawn(async move { sender.run(proof_rx).await }) + } + + #[inline] + fn eligibility_end_slot(&self) -> u64 { + let forester_end = self + .context + .forester_eligibility_end_slot + .load(Ordering::Relaxed); + if forester_end > 0 { + forester_end + } else { + self.context.epoch_phases.active.end + } + } + + #[inline] + fn should_flush_due_to_time_at(&self, current_slot: u64) -> bool { + let slots_remaining = self.eligibility_end_slot().saturating_sub(current_slot); + slots_remaining <= FLUSH_MARGIN_SLOTS + } + + #[inline] + fn is_still_eligible_at(&self, current_slot: u64) -> bool { + current_slot < self.eligibility_end_slot() + } + + async fn run( + mut self, + mut proof_rx: mpsc::Receiver, + ) -> crate::Result { + let (batch_tx, mut batch_rx) = 
mpsc::unbounded_channel::<( + Vec<(BatchInstruction, u64)>, + u64, + Option, + )>(); + + let sender_context = self.context.clone(); + let mut sender_last_root = self.last_seen_root; + let zkp_batch_size_val = self.zkp_batch_size; + + let sender_handle = tokio::spawn(async move { + let mut sender_processed = 0usize; + let mut total_tx_sending_duration = Duration::ZERO; + while let Some((batch, _batch_round_trip, batch_earliest_submit)) = + batch_rx.recv().await + { + let items_count: usize = batch.iter().map(|(instr, _)| instr.items_count()).sum(); + let first_seq = batch.first().map(|(_, s)| *s).unwrap_or(0); + let last_seq = batch.last().map(|(_, s)| *s).unwrap_or(0); + + let mut all_instructions: Vec = Vec::new(); + let mut last_root: Option<[u8; 32]> = None; + let mut append_count = 0usize; + let mut nullify_count = 0usize; + let mut _address_append_count = 0usize; + + for (instr, _seq) in &batch { + let res = match instr { + BatchInstruction::Append(proofs) => { + append_count += 1; + let ix_res = proofs + .iter() + .map(|data| { + Ok(create_batch_append_instruction( + sender_context.authority.pubkey(), + sender_context.derivation, + sender_context.merkle_tree, + sender_context.output_queue, + sender_context.epoch, + data.try_to_vec()?, + )) + }) + .collect::>>()?; + (ix_res, proofs.last().map(|p| p.new_root)) + } + BatchInstruction::Nullify(proofs) => { + nullify_count += 1; + let ix_res = proofs + .iter() + .map(|data| { + Ok(create_batch_nullify_instruction( + sender_context.authority.pubkey(), + sender_context.derivation, + sender_context.merkle_tree, + sender_context.epoch, + data.try_to_vec()?, + )) + }) + .collect::>>()?; + (ix_res, proofs.last().map(|p| p.new_root)) + } + BatchInstruction::AddressAppend(proofs) => { + _address_append_count += 1; + let ix_res = proofs + .iter() + .map(|data| { + Ok(create_batch_update_address_tree_instruction( + sender_context.authority.pubkey(), + sender_context.derivation, + sender_context.merkle_tree, + 
sender_context.epoch, + data.try_to_vec()?, + )) + }) + .collect::>>()?; + (ix_res, proofs.last().map(|p| p.new_root)) + } + }; + all_instructions.extend(res.0); + if let Some(root) = res.1 { + last_root = Some(root); + } + } + + let instr_type = if append_count > 0 && nullify_count > 0 { + format!("Append+Nullify({}+{})", append_count, nullify_count) + } else if append_count > 0 { + "Append".to_string() + } else if nullify_count > 0 { + "Nullify".to_string() + } else { + "AddressAppend".to_string() + }; + + let send_start = std::time::Instant::now(); + match send_transaction_batch(&sender_context, all_instructions).await { + Ok(sig) => { + total_tx_sending_duration += send_start.elapsed(); + if let Some(root) = last_root { + sender_last_root = root; + } + let items_processed = items_count * zkp_batch_size_val as usize; + sender_processed += items_processed; + let e2e_ms = batch_earliest_submit + .map(|t| t.elapsed().as_millis() as u64) + .unwrap_or(0); + info!( + "tx sent: {} type={} ixs={} tree={} root={:?} seq={}..{} epoch={} e2e={}ms", + sig, + instr_type, + items_count, + sender_context.merkle_tree, + &sender_last_root[..4], + first_seq, + last_seq, + sender_context.epoch, + e2e_ms, + ); + } + Err(e) => { + total_tx_sending_duration += send_start.elapsed(); + warn!("tx error {} epoch {}", e, sender_context.epoch); + if let Some(ForesterError::NotInActivePhase) = + e.downcast_ref::() + { + warn!("Active phase ended while sending tx, stopping sender loop"); + return Ok::<_, anyhow::Error>(( + sender_processed, + total_tx_sending_duration, + )); + } else { + return Err(e); + } + } + } + } + Ok((sender_processed, total_tx_sending_duration)) + }); + + loop { + if sender_handle.is_finished() { + break; + } + + let result = match proof_rx.recv().await { + Some(r) => r, + None => break, + }; + + let current_slot = self.context.slot_tracker.estimated_current_slot(); + + if !self.is_still_eligible_at(current_slot) { + let proofs_saved = self.save_proofs_to_cache(&mut 
proof_rx, Some(result)).await; + info!( + "Active phase ended for epoch {}, stopping tx sender (saved {} proofs to cache)", + self.context.epoch, proofs_saved + ); + drop(batch_tx); + let (items_processed, tx_sending_duration) = sender_handle + .await + .map_err(|e| anyhow::anyhow!("Sender panic: {}", e))??; + return Ok(TxSenderResult { + items_processed, + proof_timings: self.proof_timings, + proofs_saved_to_cache: proofs_saved, + tx_sending_duration, + }); + } + + if let Ok(instr) = &result.result { + match instr { + BatchInstruction::Append(_) => { + self.proof_timings.append_proof_ms += result.proof_duration_ms; + self.proof_timings.append_round_trip_ms += result.round_trip_ms; + } + BatchInstruction::Nullify(_) => { + self.proof_timings.nullify_proof_ms += result.proof_duration_ms; + self.proof_timings.nullify_round_trip_ms += result.round_trip_ms; + } + BatchInstruction::AddressAppend(_) => { + self.proof_timings.address_append_proof_ms += result.proof_duration_ms; + self.proof_timings.address_append_round_trip_ms += result.round_trip_ms; + } + } + } + + let instruction = match result.result { + Ok(instr) => instr, + Err(e) => { + warn!("Proof failed seq={}: {}", result.seq, e); + return Err(anyhow::anyhow!("Proof failed seq={}: {}", result.seq, e)); + } + }; + + if self.buffer.len() >= self.buffer.capacity() { + return Err(anyhow::anyhow!("Proof buffer overflow")); + } + if !self.buffer.insert( + result.seq, + instruction, + result.round_trip_ms, + result.proof_duration_ms, + result.submitted_at, + ) { + warn!("Failed to insert proof seq={}", result.seq); + } + + while let Some(entry) = self.buffer.pop_next() { + let seq = self.buffer.expected_seq() - 1; + self.pending_batch.push((entry.instruction, seq)); + self.pending_batch_round_trip_ms += entry.round_trip_ms; + self.pending_batch_proof_ms += entry.proof_ms; + self.pending_batch_earliest_submit = + Some(match self.pending_batch_earliest_submit { + None => entry.submitted_at, + Some(existing) => 
existing.min(entry.submitted_at), + }); + + let should_send = self.pending_batch.len() >= self.ixs_per_tx + || (!self.pending_batch.is_empty() + && self.should_flush_due_to_time_at(current_slot)); + + if should_send { + let batch = std::mem::replace( + &mut self.pending_batch, + Vec::with_capacity(self.ixs_per_tx), + ); + let round_trip = std::mem::replace(&mut self.pending_batch_round_trip_ms, 0); + let _proof_ms = std::mem::replace(&mut self.pending_batch_proof_ms, 0); + let earliest = self.pending_batch_earliest_submit.take(); + + if batch_tx.send((batch, round_trip, earliest)).is_err() { + break; + } + } + } + } + + if !self.pending_batch.is_empty() { + let batch = + std::mem::replace(&mut self.pending_batch, Vec::with_capacity(self.ixs_per_tx)); + let round_trip = std::mem::replace(&mut self.pending_batch_round_trip_ms, 0); + let earliest = self.pending_batch_earliest_submit.take(); + let _ = batch_tx.send((batch, round_trip, earliest)); + } + + drop(batch_tx); + let (items_processed, tx_sending_duration) = sender_handle + .await + .map_err(|e| anyhow::anyhow!("Sender panic: {}", e))??; + + Ok(TxSenderResult { + items_processed, + proof_timings: self.proof_timings, + proofs_saved_to_cache: 0, + tx_sending_duration, + }) + } + + async fn save_proofs_to_cache( + &self, + proof_rx: &mut mpsc::Receiver, + current_result: Option, + ) -> usize { + let cache = match &self.proof_cache { + Some(c) => c, + None => { + debug!("No proof cache available, discarding remaining proofs"); + return 0; + } + }; + + let mut saved = 0; + + cache.start_warming(self.last_seen_root).await; + + if let Some(result) = current_result { + if let Ok(instruction) = result.result { + cache + .add_proof(result.seq, result.old_root, result.new_root, instruction) + .await; + saved += 1; + } + } + + while let Ok(result) = proof_rx.try_recv() { + if let Ok(instruction) = result.result { + cache + .add_proof(result.seq, result.old_root, result.new_root, instruction) + .await; + saved += 1; + } + } 
+ + cache.finish_warming().await; + + if saved > 0 { + info!( + "Saved {} proofs to cache for potential reuse (root: {:?})", + saved, + &self.last_seen_root[..4] + ); + } + + saved + } +} diff --git a/forester/src/queue_helpers.rs b/forester/src/queue_helpers.rs index f2c4ee395f..aad6b5984a 100644 --- a/forester/src/queue_helpers.rs +++ b/forester/src/queue_helpers.rs @@ -36,8 +36,17 @@ pub async fn fetch_queue_item_data( return Ok(Vec::new()); } }; - let queue: HashSet = - unsafe { HashSet::from_bytes_copy(&mut account.data[8 + size_of::()..])? }; + let offset = 8 + std::mem::size_of::(); + if account.data.len() < offset { + tracing::warn!( + "Queue account {} data too short ({} < {})", + queue_pubkey, + account.data.len(), + offset + ); + return Ok(Vec::new()); + } + let queue: HashSet = unsafe { HashSet::from_bytes_copy(&mut account.data[offset..])? }; let end_index = (start_index + processing_length).min(queue_length); let filtered_queue = queue diff --git a/forester/src/rollover/operations.rs b/forester/src/rollover/operations.rs index a461dcd0a4..cfab6c1210 100644 --- a/forester/src/rollover/operations.rs +++ b/forester/src/rollover/operations.rs @@ -2,6 +2,7 @@ use account_compression::{ AddressMerkleTreeAccount, AddressMerkleTreeConfig, AddressQueueConfig, NullifierQueueConfig, QueueAccount, StateMerkleTreeAccount, StateMerkleTreeConfig, }; +use anyhow::Context; use forester_utils::{ account_zero_copy::{get_concurrent_merkle_tree, get_indexed_merkle_tree}, address_merkle_tree_config::{get_address_bundle_config, get_state_bundle_config}, @@ -11,7 +12,7 @@ use forester_utils::{ use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; use light_client::{ indexer::{AddressMerkleTreeAccounts, StateMerkleTreeAccounts}, - rpc::{Rpc, RpcError}, + rpc::Rpc, }; use light_compressed_account::TreeType; use light_hasher::Poseidon; @@ -52,7 +53,12 @@ pub async fn get_tree_fullness( let account = rpc .get_anchor_account::(&tree_pubkey) .await? 
- .unwrap(); + .ok_or_else(|| { + ForesterError::Other(anyhow::anyhow!( + "StateV1 merkle tree account not found: {}", + tree_pubkey + )) + })?; let merkle_tree = get_concurrent_merkle_tree::( @@ -78,11 +84,21 @@ pub async fn get_tree_fullness( let account = rpc .get_anchor_account::(&tree_pubkey) .await? - .unwrap(); + .ok_or_else(|| { + ForesterError::Other(anyhow::anyhow!( + "AddressV1 merkle tree account not found: {}", + tree_pubkey + )) + })?; let queue_account = rpc .get_anchor_account::(&account.metadata.associated_queue.into()) .await? - .unwrap(); + .ok_or_else(|| { + ForesterError::Other(anyhow::anyhow!( + "AddressV1 queue account not found: {:?}", + account.metadata.associated_queue + )) + })?; let merkle_tree = get_indexed_merkle_tree::( rpc, @@ -105,10 +121,12 @@ pub async fn get_tree_fullness( }) } TreeType::StateV2 => { - let mut account = rpc.get_account(tree_pubkey).await?.unwrap(); + let mut account = rpc.get_account(tree_pubkey).await?.ok_or_else(|| { + anyhow::anyhow!("StateV2 tree account not found: {}", tree_pubkey) + })?; let merkle_tree = BatchedMerkleTreeAccount::state_from_bytes(&mut account.data, &tree_pubkey.into()) - .unwrap(); + .map_err(|e| anyhow::anyhow!("Failed to parse StateV2 tree: {:?}", e))?; let height = merkle_tree.height as u64; let capacity = 1u64 << height; @@ -126,12 +144,14 @@ pub async fn get_tree_fullness( } TreeType::AddressV2 => { - let mut account = rpc.get_account(tree_pubkey).await?.unwrap(); + let mut account = rpc.get_account(tree_pubkey).await?.ok_or_else(|| { + anyhow::anyhow!("AddressV2 tree account not found: {}", tree_pubkey) + })?; let merkle_tree = BatchedMerkleTreeAccount::address_from_bytes( &mut account.data, &tree_pubkey.into(), ) - .unwrap(); + .map_err(|e| anyhow::anyhow!("Failed to parse AddressV2 tree: {:?}", e))?; let height = merkle_tree.height as u64; let capacity = 1u64 << height; @@ -172,12 +192,22 @@ pub async fn is_tree_ready_for_rollover( TreeType::StateV1 => TreeAccount::State( 
rpc.get_anchor_account::(&tree_pubkey) .await? - .unwrap(), + .ok_or_else(|| { + ForesterError::Other(anyhow::anyhow!( + "StateV1 merkle tree account not found: {}", + tree_pubkey + )) + })?, ), TreeType::AddressV1 => TreeAccount::Address( rpc.get_anchor_account::(&tree_pubkey) .await? - .unwrap(), + .ok_or_else(|| { + ForesterError::Other(anyhow::anyhow!( + "AddressV1 merkle tree account not found: {}", + tree_pubkey + )) + })?, ), _ => return Err(ForesterError::InvalidTreeType(tree_type)), }; @@ -217,7 +247,7 @@ pub async fn perform_state_merkle_tree_rollover_forester( old_queue_pubkey: &Pubkey, old_cpi_context_pubkey: &Pubkey, epoch: u64, -) -> Result { +) -> Result { let instructions = create_rollover_state_merkle_tree_instructions( context, &payer.pubkey(), @@ -230,7 +260,7 @@ pub async fn perform_state_merkle_tree_rollover_forester( old_cpi_context_pubkey, epoch, ) - .await; + .await?; let blockhash = context.get_latest_blockhash().await?; let transaction = Transaction::new_signed_with_payer( &instructions, @@ -243,7 +273,10 @@ pub async fn perform_state_merkle_tree_rollover_forester( ], blockhash.0, ); - context.process_transaction(transaction).await + context + .process_transaction(transaction) + .await + .map_err(Into::into) } #[allow(clippy::too_many_arguments)] @@ -256,7 +289,7 @@ pub async fn perform_address_merkle_tree_rollover( old_merkle_tree_pubkey: &Pubkey, old_queue_pubkey: &Pubkey, epoch: u64, -) -> Result { +) -> Result { let mut instructions = create_rollover_address_merkle_tree_instructions( context, &payer.pubkey(), @@ -267,7 +300,7 @@ pub async fn perform_address_merkle_tree_rollover( old_queue_pubkey, epoch, ) - .await; + .await?; let compute_budget_instruction = ComputeBudgetInstruction::set_compute_unit_limit(500_000); instructions.insert(0, compute_budget_instruction); let blockhash = context.get_latest_blockhash().await?; @@ -277,7 +310,10 @@ pub async fn perform_address_merkle_tree_rollover( &vec![&payer, &new_queue_keypair, 
&new_address_merkle_tree_keypair], blockhash.0, ); - context.process_transaction(transaction).await + context + .process_transaction(transaction) + .await + .map_err(Into::into) } #[allow(clippy::too_many_arguments)] @@ -290,7 +326,7 @@ pub async fn create_rollover_address_merkle_tree_instructions( merkle_tree_pubkey: &Pubkey, nullifier_queue_pubkey: &Pubkey, epoch: u64, -) -> Vec { +) -> Result, ForesterError> { let (merkle_tree_config, queue_config) = get_address_bundle_config( rpc, AddressMerkleTreeAccounts { @@ -305,7 +341,7 @@ pub async fn create_rollover_address_merkle_tree_instructions( &merkle_tree_config, &queue_config, ) - .await; + .await?; let create_nullifier_queue_instruction = create_account_instruction( authority, queue_rent_exemption.size, @@ -334,11 +370,11 @@ pub async fn create_rollover_address_merkle_tree_instructions( }, epoch, ); - vec![ + Ok(vec![ create_nullifier_queue_instruction, create_state_merkle_tree_instruction, instruction, - ] + ]) } #[allow(clippy::too_many_arguments)] @@ -353,7 +389,7 @@ pub async fn create_rollover_state_merkle_tree_instructions( nullifier_queue_pubkey: &Pubkey, old_cpi_context_pubkey: &Pubkey, epoch: u64, -) -> Vec { +) -> Result, ForesterError> { let (merkle_tree_config, queue_config) = get_state_bundle_config( rpc, StateMerkleTreeAccounts { @@ -366,7 +402,7 @@ pub async fn create_rollover_state_merkle_tree_instructions( .await; let (state_merkle_tree_rent_exemption, queue_rent_exemption) = get_rent_exemption_for_state_merkle_tree_and_queue(rpc, &merkle_tree_config, &queue_config) - .await; + .await?; let create_nullifier_queue_instruction = create_account_instruction( authority, queue_rent_exemption.size, @@ -385,7 +421,7 @@ pub async fn create_rollover_state_merkle_tree_instructions( let rent_cpi_config = rpc .get_minimum_balance_for_rent_exemption(ProtocolConfig::default().cpi_context_size as usize) .await - .unwrap(); + .context("Failed to fetch rent exemption for CPI context")?; let 
create_cpi_context_instruction = create_account_instruction( authority, ProtocolConfig::default().cpi_context_size as usize, @@ -407,25 +443,26 @@ pub async fn create_rollover_state_merkle_tree_instructions( }, epoch, ); - vec![ + Ok(vec![ create_cpi_context_instruction, create_nullifier_queue_instruction, create_state_merkle_tree_instruction, instruction, - ] + ]) } pub async fn get_rent_exemption_for_state_merkle_tree_and_queue( rpc: &mut R, merkle_tree_config: &StateMerkleTreeConfig, queue_config: &NullifierQueueConfig, -) -> (RentExemption, RentExemption) { - let queue_size = QueueAccount::size(queue_config.capacity as usize).unwrap(); +) -> Result<(RentExemption, RentExemption), ForesterError> { + let queue_size = QueueAccount::size(queue_config.capacity as usize) + .context("Failed to compute StateV1 queue account size")?; let queue_rent_exempt_lamports = rpc .get_minimum_balance_for_rent_exemption(queue_size) .await - .unwrap(); + .context("Failed to fetch rent exemption for StateV1 queue account")?; let tree_size = StateMerkleTreeAccount::size( merkle_tree_config.height as usize, merkle_tree_config.changelog_size as usize, @@ -435,8 +472,8 @@ pub async fn get_rent_exemption_for_state_merkle_tree_and_queue( let merkle_tree_rent_exempt_lamports = rpc .get_minimum_balance_for_rent_exemption(tree_size) .await - .unwrap(); - ( + .context("Failed to fetch rent exemption for StateV1 merkle tree account")?; + Ok(( RentExemption { lamports: merkle_tree_rent_exempt_lamports, size: tree_size, @@ -445,20 +482,21 @@ pub async fn get_rent_exemption_for_state_merkle_tree_and_queue( lamports: queue_rent_exempt_lamports, size: queue_size, }, - ) + )) } pub async fn get_rent_exemption_for_address_merkle_tree_and_queue( rpc: &mut R, address_merkle_tree_config: &AddressMerkleTreeConfig, address_queue_config: &AddressQueueConfig, -) -> (RentExemption, RentExemption) { - let queue_size = QueueAccount::size(address_queue_config.capacity as usize).unwrap(); +) -> 
Result<(RentExemption, RentExemption), ForesterError> { + let queue_size = QueueAccount::size(address_queue_config.capacity as usize) + .context("Failed to compute AddressV1 queue account size")?; let queue_rent_exempt_lamports = rpc .get_minimum_balance_for_rent_exemption(queue_size) .await - .unwrap(); + .context("Failed to fetch rent exemption for AddressV1 queue account")?; let tree_size = AddressMerkleTreeAccount::size( address_merkle_tree_config.height as usize, address_merkle_tree_config.changelog_size as usize, @@ -469,8 +507,8 @@ pub async fn get_rent_exemption_for_address_merkle_tree_and_queue( let merkle_tree_rent_exempt_lamports = rpc .get_minimum_balance_for_rent_exemption(tree_size) .await - .unwrap(); - ( + .context("Failed to fetch rent exemption for AddressV1 merkle tree account")?; + Ok(( RentExemption { lamports: merkle_tree_rent_exempt_lamports, size: tree_size, @@ -479,5 +517,5 @@ pub async fn get_rent_exemption_for_address_merkle_tree_and_queue( lamports: queue_rent_exempt_lamports, size: queue_size, }, - ) + )) } diff --git a/forester/src/slot_tracker.rs b/forester/src/slot_tracker.rs index 8fb6179c75..30d9842745 100644 --- a/forester/src/slot_tracker.rs +++ b/forester/src/slot_tracker.rs @@ -3,7 +3,7 @@ use std::{ atomic::{AtomicU64, Ordering}, Arc, }, - time::{SystemTime, UNIX_EPOCH}, + time::Instant, }; use light_client::rpc::Rpc; @@ -14,50 +14,50 @@ pub fn slot_duration() -> Duration { Duration::from_nanos(solana_sdk::genesis_config::GenesisConfig::default().ns_per_slot() as u64) } +fn slot_duration_secs() -> f64 { + static SLOT_DURATION_SECS: std::sync::OnceLock = std::sync::OnceLock::new(); + *SLOT_DURATION_SECS.get_or_init(|| slot_duration().as_secs_f64()) +} + #[derive(Debug)] pub struct SlotTracker { last_known_slot: AtomicU64, - last_update_time: AtomicU64, + last_update_nanos: AtomicU64, + reference_instant: Instant, update_interval: Duration, } impl SlotTracker { pub fn new(initial_slot: u64, update_interval: Duration) -> Self { - 
let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis() as u64; + let reference = Instant::now(); Self { last_known_slot: AtomicU64::new(initial_slot), - last_update_time: AtomicU64::new(now), + last_update_nanos: AtomicU64::new(0), + reference_instant: reference, update_interval, } } pub fn update(&self, new_slot: u64) { - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis() as u64; + let elapsed_nanos = self.reference_instant.elapsed().as_nanos() as u64; self.last_known_slot.store(new_slot, Ordering::Release); - self.last_update_time.store(now, Ordering::Release); + self.last_update_nanos + .store(elapsed_nanos, Ordering::Release); } + #[inline] pub fn estimated_current_slot(&self) -> u64 { let last_slot = self.last_known_slot.load(Ordering::Acquire); - let last_update = self.last_update_time.load(Ordering::Acquire); - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_millis() as u64; - let elapsed = Duration::from_millis(now - last_update); - let estimated_slot = - last_slot + (elapsed.as_secs_f64() / slot_duration().as_secs_f64()) as u64; + let last_update_nanos = self.last_update_nanos.load(Ordering::Acquire); + let current_nanos = self.reference_instant.elapsed().as_nanos() as u64; + let elapsed_nanos = current_nanos.saturating_sub(last_update_nanos); + let elapsed_secs = elapsed_nanos as f64 / 1_000_000_000.0; + let estimated_slot = last_slot + (elapsed_secs / slot_duration_secs()) as u64; trace!( "Estimated current slot: {} (last known: {}, elapsed: {:?})", estimated_slot, last_slot, - elapsed + Duration::from_nanos(elapsed_nanos) ); estimated_slot } @@ -70,7 +70,6 @@ impl SlotTracker { } Err(e) => error!("Failed to get slot: {:?}", e), } - tokio::task::yield_now().await; tokio::time::sleep(self.update_interval).await; } } @@ -108,7 +107,6 @@ pub async fn wait_until_slot_reached( sleep_duration.as_secs_f64() ); - tokio::task::yield_now().await; 
sleep(sleep_duration).await; } diff --git a/forester/src/tree_data_sync.rs b/forester/src/tree_data_sync.rs index 16203077cb..3f0f2bd3ce 100644 --- a/forester/src/tree_data_sync.rs +++ b/forester/src/tree_data_sync.rs @@ -8,14 +8,45 @@ use light_batched_merkle_tree::merkle_tree::BatchedMerkleTreeAccount; use light_client::rpc::Rpc; use light_compressed_account::TreeType; use light_merkle_tree_metadata::merkle_tree::MerkleTreeMetadata; +use serde_json::json; use solana_sdk::{account::Account, pubkey::Pubkey}; -use tracing::trace; +use tracing::{debug, trace, warn}; use crate::{errors::AccountDeserializationError, Result}; +// Discriminators for filtering getProgramAccounts +// BatchedMerkleTreeAccount: b"BatchMta" +const BATCHED_TREE_DISCRIMINATOR: [u8; 8] = [66, 97, 116, 99, 104, 77, 116, 97]; +// StateMerkleTreeAccount: sha256("account:StateMerkleTreeAccount")[0..8] +const STATE_V1_DISCRIMINATOR: [u8; 8] = [172, 43, 172, 186, 29, 73, 219, 84]; +// AddressMerkleTreeAccount: sha256("account:AddressMerkleTreeAccount")[0..8] +const ADDRESS_V1_DISCRIMINATOR: [u8; 8] = [11, 161, 175, 9, 212, 229, 73, 73]; + +/// Fetch trees using filtered getProgramAccounts calls (optimized for remote RPCs). +/// Falls back to unfiltered fetch if the filtered approach fails. 
pub async fn fetch_trees(rpc: &R) -> Result> { + let rpc_url = rpc.get_url(); + + // Try filtered approach first (much faster for remote RPCs) + match fetch_trees_filtered(&rpc_url).await { + Ok(trees) => { + trace!("Fetched {} trees using filtered queries", trees.len()); + Ok(trees) + } + Err(e) => { + warn!( + "Filtered tree fetch failed, falling back to unfiltered: {:?}", + e + ); + fetch_trees_unfiltered(rpc).await + } + } +} + +/// Fetch trees without filters (original implementation, slower but more reliable) +pub async fn fetch_trees_unfiltered(rpc: &R) -> Result> { let program_id = account_compression::id(); - trace!("Fetching accounts for program: {}", program_id); + trace!("Fetching accounts for program (unfiltered): {}", program_id); Ok(rpc .get_program_accounts(&program_id) .await? @@ -24,6 +55,177 @@ pub async fn fetch_trees(rpc: &R) -> Result> { .collect()) } +/// Fetch trees using filtered getProgramAccounts calls with discriminator memcmp filters. +/// Makes 3 parallel requests (one per tree type) instead of fetching all accounts. 
+pub async fn fetch_trees_filtered(rpc_url: &str) -> Result> { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .build()?; + let program_id = account_compression::id(); + + // Fetch all three types in parallel + let (batched_result, state_v1_result, address_v1_result) = tokio::join!( + fetch_accounts_with_discriminator( + &client, + rpc_url, + &program_id, + &BATCHED_TREE_DISCRIMINATOR + ), + fetch_accounts_with_discriminator(&client, rpc_url, &program_id, &STATE_V1_DISCRIMINATOR), + fetch_accounts_with_discriminator(&client, rpc_url, &program_id, &ADDRESS_V1_DISCRIMINATOR), + ); + + let mut all_trees = Vec::new(); + let mut errors = Vec::new(); + + // Process batched trees (V2) - need to distinguish state vs address + match batched_result { + Ok(accounts) => { + debug!("Fetched {} batched tree accounts", accounts.len()); + for (pubkey, mut account) in accounts { + // Try state first, then address + if let Ok(tree) = process_batch_state_account(&mut account, pubkey) { + all_trees.push(tree); + } else if let Ok(tree) = process_batch_address_account(&mut account, pubkey) { + all_trees.push(tree); + } + } + } + Err(e) => { + warn!("Failed to fetch batched trees: {:?}", e); + errors.push(format!("batched: {}", e)); + } + } + + // Process state V1 trees + match state_v1_result { + Ok(accounts) => { + debug!("Fetched {} state V1 tree accounts", accounts.len()); + for (pubkey, account) in accounts { + if let Ok(tree) = process_state_account(&account, pubkey) { + all_trees.push(tree); + } + } + } + Err(e) => { + warn!("Failed to fetch state V1 trees: {:?}", e); + errors.push(format!("state_v1: {}", e)); + } + } + + // Process address V1 trees + match address_v1_result { + Ok(accounts) => { + debug!("Fetched {} address V1 tree accounts", accounts.len()); + for (pubkey, account) in accounts { + if let Ok(tree) = process_address_account(&account, pubkey) { + all_trees.push(tree); + } + } + } + Err(e) => { + warn!("Failed to fetch 
address V1 trees: {:?}", e); + errors.push(format!("address_v1: {}", e)); + } + } + + // Only return error if all queries failed; empty-but-successful is Ok + if !errors.is_empty() && all_trees.is_empty() { + return Err(anyhow::anyhow!( + "All filtered queries failed: {}", + errors.join(", ") + )); + } + + Ok(all_trees) +} + +/// Fetch accounts from a program with a specific discriminator filter +async fn fetch_accounts_with_discriminator( + client: &reqwest::Client, + rpc_url: &str, + program_id: &Pubkey, + discriminator: &[u8; 8], +) -> Result> { + let discriminator_base58 = bs58::encode(discriminator).into_string(); + + let payload = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "getProgramAccounts", + "params": [ + program_id.to_string(), + { + "encoding": "base64", + "commitment": "confirmed", + "filters": [ + { + "memcmp": { + "offset": 0, + "bytes": discriminator_base58 + } + } + ] + } + ] + }); + + let response = client.post(rpc_url).json(&payload).send().await?; + + if !response.status().is_success() { + return Err(anyhow::anyhow!("HTTP error: {}", response.status())); + } + + let json_response: serde_json::Value = response.json().await?; + + if let Some(error) = json_response.get("error") { + return Err(anyhow::anyhow!("RPC error: {:?}", error)); + } + + let accounts_array = json_response + .get("result") + .and_then(|v| v.as_array()) + .ok_or_else(|| anyhow::anyhow!("Unexpected response format"))?; + + let mut accounts = Vec::with_capacity(accounts_array.len()); + + for account_value in accounts_array { + if let Some((pubkey, account)) = parse_account_from_json(account_value) { + accounts.push((pubkey, account)); + } + } + + Ok(accounts) +} + +/// Parse a single account from JSON RPC response +fn parse_account_from_json(value: &serde_json::Value) -> Option<(Pubkey, Account)> { + let pubkey_str = value.get("pubkey")?.as_str()?; + let pubkey: Pubkey = pubkey_str.parse().ok()?; + + let account_obj = value.get("account")?; + let lamports = 
account_obj.get("lamports")?.as_u64()?; + let owner_str = account_obj.get("owner")?.as_str()?; + let owner: Pubkey = owner_str.parse().ok()?; + let executable = account_obj.get("executable")?.as_bool().unwrap_or(false); + let rent_epoch = account_obj.get("rentEpoch")?.as_u64().unwrap_or(0); + + let data_array = account_obj.get("data")?.as_array()?; + let data_str = data_array.first()?.as_str()?; + let data = base64::decode(data_str).ok()?; + + Some(( + pubkey, + Account { + lamports, + data, + owner, + executable, + rent_epoch, + }, + )) +} + fn process_account(pubkey: Pubkey, mut account: Account) -> Option { process_state_account(&account, pubkey) .or_else(|_| process_batch_state_account(&mut account, pubkey)) @@ -103,10 +305,11 @@ fn create_tree_accounts( ); trace!( - "{:?} Merkle Tree account found. Pubkey: {}. Queue pubkey: {}", + "{:?} Merkle Tree account found. Pubkey: {}. Queue pubkey: {}. Rolledover: {}", tree_type, pubkey, - tree_accounts.queue + tree_accounts.queue, + tree_accounts.is_rolledover ); tree_accounts } diff --git a/forester/src/tree_finder.rs b/forester/src/tree_finder.rs deleted file mode 100644 index 30fdeb7cc1..0000000000 --- a/forester/src/tree_finder.rs +++ /dev/null @@ -1,73 +0,0 @@ -use std::sync::Arc; - -use forester_utils::{forester_epoch::TreeAccounts, rpc_pool::SolanaRpcPool}; -use light_client::rpc::Rpc; -use tokio::{ - sync::broadcast, - time::{interval, Duration}, -}; -use tracing::{error, trace}; - -use crate::{tree_data_sync::fetch_trees, Result}; - -pub struct TreeFinder { - rpc_pool: Arc>, - known_trees: Vec, - new_tree_sender: broadcast::Sender, - check_interval: Duration, -} - -impl TreeFinder { - pub fn new( - rpc_pool: Arc>, - initial_trees: Vec, - new_tree_sender: broadcast::Sender, - check_interval: Duration, - ) -> Self { - Self { - rpc_pool, - known_trees: initial_trees, - new_tree_sender, - check_interval, - } - } - - pub async fn run(&mut self) -> Result<()> { - let mut interval = interval(self.check_interval); - - 
loop { - interval.tick().await; - trace!("Checking for new trees"); - - match self.check_for_new_trees().await { - Ok(new_trees) => { - for tree in new_trees { - if let Err(e) = self.new_tree_sender.send(tree) { - error!("Failed to send new tree: {:?}", e); - } else { - trace!("New tree discovered: {:?}", tree); - self.known_trees.push(tree); - } - } - } - Err(e) => { - error!("Error checking for new trees: {:?}", e); - } - } - - tokio::task::yield_now().await; - } - } - - async fn check_for_new_trees(&self) -> Result> { - let rpc = self.rpc_pool.get_connection().await?; - let current_trees = fetch_trees(&*rpc).await?; - - let new_trees: Vec = current_trees - .into_iter() - .filter(|tree| !self.known_trees.contains(tree)) - .collect(); - - Ok(new_trees) - } -} diff --git a/forester/src/utils.rs b/forester/src/utils.rs index 164f0a7067..9590c63eb3 100644 --- a/forester/src/utils.rs +++ b/forester/src/utils.rs @@ -5,22 +5,25 @@ use light_registry::{ protocol_config::state::{ProtocolConfig, ProtocolConfigPda}, utils::get_protocol_config_pda_address, }; -use tracing::debug; +use tracing::{debug, warn}; -pub async fn get_protocol_config(rpc: &mut R) -> ProtocolConfig { +pub async fn get_protocol_config(rpc: &mut R) -> crate::Result { let authority_pda = get_protocol_config_pda_address(); let protocol_config_account = rpc .get_anchor_account::(&authority_pda.0) .await - .unwrap() - .unwrap(); + .map_err(|e| anyhow::anyhow!("Failed to fetch protocol config account: {}", e))? 
+ .ok_or_else(|| anyhow::anyhow!("Protocol config account not found"))?; debug!("Protocol config account: {:?}", protocol_config_account); - protocol_config_account.config + Ok(protocol_config_account.config) } pub fn get_current_system_time_ms() -> u128 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_millis() + match SystemTime::now().duration_since(UNIX_EPOCH) { + Ok(d) => d.as_millis(), + Err(e) => { + warn!("SystemTime went backwards: {}", e); + 0 + } + } } diff --git a/forester/tests/e2e_test.rs b/forester/tests/e2e_test.rs index c1312ba6a8..a4452027d2 100644 --- a/forester/tests/e2e_test.rs +++ b/forester/tests/e2e_test.rs @@ -222,6 +222,7 @@ async fn e2e_test() { photon_rate_limit: None, send_tx_rate_limit: None, }, + lookup_table_address: None, retry_config: Default::default(), queue_config: Default::default(), indexer_config: Default::default(), @@ -236,9 +237,10 @@ async fn e2e_test() { skip_v2_state_trees: false, skip_v1_address_trees: false, skip_v2_address_trees: false, - tree_id: None, + tree_ids: vec![], sleep_after_processing_ms: 50, sleep_when_idle_ms: 100, + queue_polling_mode: Default::default(), }, rpc_pool_config: RpcPoolConfig { max_size: 50, @@ -419,7 +421,9 @@ async fn e2e_test() { println!("seed {}", rng_seed); let rng = &mut StdRng::seed_from_u64(rng_seed); - let protocol_config = get_protocol_config(&mut rpc).await; + let protocol_config = get_protocol_config(&mut rpc) + .await + .expect("Failed to fetch protocol config"); let registration_phase_slot = get_registration_phase_start_slot(&mut rpc, &protocol_config).await; @@ -457,7 +461,10 @@ async fn e2e_test() { compressible_account_subscriber ); - execute_test_transactions( + let iterations: usize = 100; + + let test_iterations = execute_test_transactions( + iterations, &mut rpc, rng, &env, @@ -472,7 +479,12 @@ async fn e2e_test() { ) .await; - wait_for_work_report(&mut work_report_receiver, &state_tree_params).await; + wait_for_work_report( + 
&mut work_report_receiver, + &state_tree_params, + test_iterations, + ) + .await; // Verify root changes based on enabled tests if is_v1_state_test_enabled() { @@ -585,7 +597,12 @@ async fn get_initial_merkle_tree_state( .get_anchor_account::(merkle_tree_pubkey) .await .unwrap() - .unwrap(); + .unwrap_or_else(|| { + panic!( + "StateV1 merkle tree account not found: {}", + merkle_tree_pubkey + ) + }); let merkle_tree = get_concurrent_merkle_tree::( @@ -601,11 +618,20 @@ async fn get_initial_merkle_tree_state( (next_index, sequence_number, root) } TreeType::AddressV1 => { + println!( + "Fetching initial state for V1 address tree: {:?}", + merkle_tree_pubkey + ); let account = rpc .get_anchor_account::(merkle_tree_pubkey) .await .unwrap() - .unwrap(); + .unwrap_or_else(|| { + panic!( + "AddressV1 merkle tree account not found: {}", + merkle_tree_pubkey + ) + }); let merkle_tree = get_indexed_merkle_tree::< AddressMerkleTreeAccount, @@ -743,17 +769,6 @@ async fn verify_root_changed( ); } -async fn get_state_v2_batch_size(rpc: &mut R, merkle_tree_pubkey: &Pubkey) -> u64 { - let mut merkle_tree_account = rpc.get_account(*merkle_tree_pubkey).await.unwrap().unwrap(); - let merkle_tree = BatchedMerkleTreeAccount::state_from_bytes( - merkle_tree_account.data.as_mut_slice(), - &merkle_tree_pubkey.into(), - ) - .unwrap(); - - merkle_tree.get_metadata().queue_batches.batch_size -} - async fn setup_forester_pipeline( config: &ForesterConfig, ) -> ( @@ -791,14 +806,10 @@ async fn setup_forester_pipeline( async fn wait_for_work_report( work_report_receiver: &mut mpsc::Receiver, tree_params: &InitStateTreeAccountsInstructionData, + expected_minimum_processed_items: usize, ) { let batch_size = tree_params.output_queue_zkp_batch_size as usize; - // With increased test size, expect more processed items - let minimum_processed_items: usize = if is_v2_state_test_enabled() { - (tree_params.output_queue_batch_size as usize) * 4 // Expect at least 4 batches worth - } else { - 
tree_params.output_queue_batch_size as usize - }; + let mut total_processed_items: usize = 0; let timeout_duration = Duration::from_secs(DEFAULT_TIMEOUT_SECONDS); @@ -806,11 +817,11 @@ async fn wait_for_work_report( println!("Batch size: {}", batch_size); println!( "Minimum required processed items: {}", - minimum_processed_items + expected_minimum_processed_items ); let start_time = tokio::time::Instant::now(); - while total_processed_items < minimum_processed_items { + while total_processed_items < expected_minimum_processed_items { match timeout( timeout_duration.saturating_sub(start_time.elapsed()), work_report_receiver.recv(), @@ -820,6 +831,11 @@ async fn wait_for_work_report( Ok(Some(report)) => { println!("Received work report: {:?}", report); total_processed_items += report.processed_items; + + if total_processed_items >= expected_minimum_processed_items { + println!("Received required number of processed items."); + break; + } } Ok(None) => { println!("Work report channel closed unexpectedly"); @@ -834,15 +850,16 @@ async fn wait_for_work_report( println!("Total processed items: {}", total_processed_items); assert!( - total_processed_items >= minimum_processed_items, + total_processed_items >= expected_minimum_processed_items, "Processed fewer items ({}) than required ({})", total_processed_items, - minimum_processed_items + expected_minimum_processed_items ); } #[allow(clippy::too_many_arguments)] async fn execute_test_transactions( + iterations: usize, rpc: &mut R, rng: &mut StdRng, env: &TestAccounts, @@ -854,14 +871,7 @@ async fn execute_test_transactions( sender_batched_token_counter: &mut u64, address_v1_counter: &mut u64, address_v2_counter: &mut u64, -) { - let mut iterations = 4; - if is_v2_state_test_enabled() { - let batch_size = - get_state_v2_batch_size(rpc, &env.v2_state_trees[0].merkle_tree).await as usize; - iterations = batch_size * 2; - } - +) -> usize { println!("Executing {} test transactions", iterations); 
println!("==========================================="); for i in 0..iterations { @@ -966,6 +976,8 @@ async fn execute_test_transactions( println!("{} v2 address create: {:?}", i, sig_v2_addr); } } + + iterations } async fn mint_to( diff --git a/forester/tests/legacy/e2e_test.rs b/forester/tests/legacy/e2e_test.rs index 9dc712eb3f..fe9be114a5 100644 --- a/forester/tests/legacy/e2e_test.rs +++ b/forester/tests/legacy/e2e_test.rs @@ -482,7 +482,9 @@ async fn test_epoch_double_registration() { } let mut rpc = pool.get_connection().await.unwrap(); - let protocol_config = get_protocol_config(&mut *rpc).await; + let protocol_config = get_protocol_config(&mut *rpc) + .await + .expect("Failed to fetch protocol config"); let solana_slot = rpc.get_slot().await.unwrap(); let current_epoch = protocol_config.get_current_epoch(solana_slot); diff --git a/forester/tests/legacy/e2e_v1_test.rs b/forester/tests/legacy/e2e_v1_test.rs index 050ece14af..9c59d5a51e 100644 --- a/forester/tests/legacy/e2e_v1_test.rs +++ b/forester/tests/legacy/e2e_v1_test.rs @@ -479,7 +479,9 @@ async fn test_epoch_double_registration() { } let mut rpc = pool.get_connection().await.unwrap(); - let protocol_config = get_protocol_config(&mut *rpc).await; + let protocol_config = get_protocol_config(&mut *rpc) + .await + .expect("Failed to fetch protocol config"); let solana_slot = rpc.get_slot().await.unwrap(); let current_epoch = protocol_config.get_current_epoch(solana_slot); diff --git a/forester/tests/legacy/test_utils.rs b/forester/tests/legacy/test_utils.rs index 35c7e0e3fa..fb1c42a3ef 100644 --- a/forester/tests/legacy/test_utils.rs +++ b/forester/tests/legacy/test_utils.rs @@ -101,7 +101,7 @@ pub fn forester_config() -> ForesterConfig { skip_v2_state_trees: false, skip_v1_address_trees: false, skip_v2_address_trees: false, - tree_id: None, + tree_ids: vec![], sleep_after_processing_ms: 50, sleep_when_idle_ms: 100, }, diff --git a/forester/tests/priority_fee_test.rs b/forester/tests/priority_fee_test.rs 
index 5656781122..96475c429e 100644 --- a/forester/tests/priority_fee_test.rs +++ b/forester/tests/priority_fee_test.rs @@ -62,6 +62,8 @@ async fn test_priority_fee_request() { max_concurrent_sends: 50, tx_cache_ttl_seconds: 15, ops_cache_ttl_seconds: 180, + confirmation_max_attempts: 30, + confirmation_poll_interval_ms: 1000, cu_limit: 1_000_000, enable_priority_fees: true, rpc_pool_size: 20, @@ -83,8 +85,10 @@ async fn test_priority_fee_request() { photon_rate_limit: None, send_tx_rate_limit: None, processor_mode: ProcessorMode::All, - tree_id: None, + queue_polling_mode: Default::default(), + tree_ids: vec![], enable_compressible: true, + lookup_table_address: None, }; let config = ForesterConfig::new_for_start(&args).expect("Failed to create config"); diff --git a/forester/tests/test_utils.rs b/forester/tests/test_utils.rs index 1471cfe0cc..5c1ce7a4b7 100644 --- a/forester/tests/test_utils.rs +++ b/forester/tests/test_utils.rs @@ -1,6 +1,7 @@ use std::time::Duration; use forester::{ + cli::QueuePollingMode, config::{ExternalServicesConfig, GeneralConfig, RpcPoolConfig}, metrics::register_metrics, telemetry::setup_telemetry, @@ -110,9 +111,10 @@ pub fn forester_config() -> ForesterConfig { skip_v2_state_trees: false, skip_v1_address_trees: false, skip_v2_address_trees: false, - tree_id: None, + tree_ids: vec![], sleep_after_processing_ms: 50, sleep_when_idle_ms: 100, + queue_polling_mode: QueuePollingMode::OnChain, }, rpc_pool_config: RpcPoolConfig { max_size: 50, @@ -128,6 +130,7 @@ pub fn forester_config() -> ForesterConfig { address_tree_data: vec![], state_tree_data: vec![], compressible_config: None, + lookup_table_address: None, } } @@ -328,6 +331,34 @@ pub async fn get_active_phase_start_slot( phases.active.start } +/// Get the active phase start slot for an epoch with enough time remaining. +/// If the current epoch's active phase has less than `min_slots_remaining` slots, +/// returns the next epoch's active phase start. 
+#[allow(dead_code)] +pub async fn get_next_active_phase_with_time( + rpc: &mut R, + protocol_config: &ProtocolConfig, + min_slots_remaining: u64, +) -> u64 { + let current_slot = rpc.get_slot().await.unwrap(); + let current_epoch = protocol_config.get_current_epoch(current_slot); + let phases = get_epoch_phases(protocol_config, current_epoch); + + // Check if current epoch has enough time remaining + let slots_remaining = phases.active.end.saturating_sub(current_slot); + if slots_remaining >= min_slots_remaining { + phases.active.start + } else { + // Use next epoch + let next_phases = get_epoch_phases(protocol_config, current_epoch + 1); + println!( + "Current epoch {} has only {} slots remaining, using epoch {} (active phase starts at slot {})", + current_epoch, slots_remaining, current_epoch + 1, next_phases.active.start + ); + next_phases.active.start + } +} + #[allow(dead_code)] pub async fn wait_for_slot(rpc: &mut LightClient, target_slot: u64) { while rpc.get_slot().await.unwrap() < target_slot { diff --git a/program-tests/registry-test/tests/tests.rs b/program-tests/registry-test/tests/tests.rs index 6d44bd8872..fd659a5302 100644 --- a/program-tests/registry-test/tests/tests.rs +++ b/program-tests/registry-test/tests/tests.rs @@ -890,6 +890,7 @@ async fn test_register_and_update_forester_pda() { &protocol_config, &forester_keypair, &forester_keypair.pubkey(), + None, ) .await .unwrap(); @@ -990,6 +991,7 @@ async fn test_register_and_update_forester_pda() { &protocol_config, &forester_keypair, &forester_keypair.pubkey(), + None, ) .await .unwrap(); diff --git a/program-tests/system-test/tests/test.rs b/program-tests/system-test/tests/test.rs index 2761ccc3f3..5768a9a66f 100644 --- a/program-tests/system-test/tests/test.rs +++ b/program-tests/system-test/tests/test.rs @@ -2,7 +2,11 @@ use account_compression::errors::AccountCompressionErrorCode; use anchor_lang::{AnchorSerialize, InstructionData, ToAccountMetas}; -use 
light_batched_merkle_tree::{errors::BatchedMerkleTreeError, queue::BatchedQueueAccount}; +use light_batched_merkle_tree::{ + errors::BatchedMerkleTreeError, + initialize_address_tree::InitAddressTreeAccountsInstructionData, + initialize_state_tree::InitStateTreeAccountsInstructionData, queue::BatchedQueueAccount, +}; use light_client::indexer::{AddressWithTree, Indexer}; use light_compressed_account::{ address::{derive_address, derive_address_legacy}, @@ -1656,6 +1660,11 @@ async fn regenerate_accounts() { }; let mut config = ProgramTestConfig::default_with_batched_trees(false); config.protocol_config = protocol_config; + + // Use testnet/devnet/mainnet tree configs (batch_size=15000, zkp_batch_size=500/250) + config.v2_state_tree_config = Some(InitStateTreeAccountsInstructionData::default()); + config.v2_address_tree_config = Some(InitAddressTreeAccountsInstructionData::default()); + let mut rpc = LightProgramTest::new(config).await.unwrap(); let env = rpc.test_accounts.clone(); let keypairs = for_regenerate_accounts(); diff --git a/program-tests/utils/src/e2e_test_env.rs b/program-tests/utils/src/e2e_test_env.rs index 481f103914..4025598554 100644 --- a/program-tests/utils/src/e2e_test_env.rs +++ b/program-tests/utils/src/e2e_test_env.rs @@ -835,10 +835,8 @@ where if response_result.status().is_success() { let body = response_result.text().await.unwrap(); let proof_json = deserialize_gnark_proof_json(&body).unwrap(); - let (proof_a, proof_b, proof_c) = - proof_from_json_struct(proof_json); - let (proof_a, proof_b, proof_c) = - compress_proof(&proof_a, &proof_b, &proof_c); + let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); + let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); let instruction_data = InstructionDataBatchNullifyInputs { new_root: circuit_inputs_new_root, compressed_proof: CompressedProof { @@ -1035,6 +1033,7 @@ where &self.protocol_config, &forester.keypair, &forester.keypair.pubkey(), + None, ) .await 
.unwrap() diff --git a/program-tests/utils/src/mock_batched_forester.rs b/program-tests/utils/src/mock_batched_forester.rs index 990f912d6c..732ed0af16 100644 --- a/program-tests/utils/src/mock_batched_forester.rs +++ b/program-tests/utils/src/mock_batched_forester.rs @@ -145,9 +145,9 @@ impl MockBatchedForester { } }; let proof = CompressedProof { - a: proof_result.0.a, - b: proof_result.0.b, - c: proof_result.0.c, + a: proof_result.0.proof.a, + b: proof_result.0.proof.b, + c: proof_result.0.proof.c, }; Ok((proof, proof_result.1)) } @@ -211,9 +211,9 @@ impl MockBatchedForester { .await?; let new_root = self.merkle_tree.root(); let proof = CompressedProof { - a: proof_result.0.a, - b: proof_result.0.b, - c: proof_result.0.c, + a: proof_result.0.proof.a, + b: proof_result.0.proof.b, + c: proof_result.0.proof.c, }; Ok((proof, new_root)) } @@ -319,9 +319,9 @@ impl MockBatchedAddressForester { } }; let proof = CompressedProof { - a: proof_result.0.a, - b: proof_result.0.b, - c: proof_result.0.c, + a: proof_result.0.proof.a, + b: proof_result.0.proof.b, + c: proof_result.0.proof.c, }; Ok((proof, proof_result.1)) } diff --git a/program-tests/utils/src/setup_forester.rs b/program-tests/utils/src/setup_forester.rs index 6f1fc17630..b82698212d 100644 --- a/program-tests/utils/src/setup_forester.rs +++ b/program-tests/utils/src/setup_forester.rs @@ -41,6 +41,7 @@ pub async fn setup_forester_and_advance_to_epoch( protocol_config, &test_keypairs.forester, &test_keypairs.forester.pubkey(), + None, ) .await? 
.ok_or_else(|| RpcError::CustomError("Failed to register epoch".to_string()))?; diff --git a/program-tests/utils/src/test_batch_forester.rs b/program-tests/utils/src/test_batch_forester.rs index 0fcc2441b6..416bb3b80c 100644 --- a/program-tests/utils/src/test_batch_forester.rs +++ b/program-tests/utils/src/test_batch_forester.rs @@ -182,9 +182,9 @@ pub async fn create_append_batch_ix_data( InstructionDataBatchAppendInputs { new_root, compressed_proof: CompressedProof { - a: proof.a, - b: proof.b, - c: proof.c, + a: proof.proof.a, + b: proof.proof.b, + c: proof.proof.c, }, } } @@ -311,9 +311,9 @@ pub async fn get_batched_nullify_ix_data( Ok(InstructionDataBatchNullifyInputs { new_root, compressed_proof: CompressedProof { - a: proof.a, - b: proof.b, - c: proof.c, + a: proof.proof.a, + b: proof.proof.b, + c: proof.proof.c, }, }) } @@ -734,9 +734,9 @@ pub async fn create_batch_update_address_tree_instruction_data_with_proof for ProverClientError { diff --git a/prover/client/src/helpers.rs b/prover/client/src/helpers.rs index 61e7dfb6d9..6ea223e79f 100644 --- a/prover/client/src/helpers.rs +++ b/prover/client/src/helpers.rs @@ -5,7 +5,6 @@ use light_sparse_merkle_tree::changelog::ChangelogEntry; use num_bigint::{BigInt, BigUint}; use num_traits::{Num, ToPrimitive}; use serde::Serialize; -use serde_json::json; pub fn get_project_root() -> Option { let output = Command::new("git") @@ -21,11 +20,9 @@ pub fn get_project_root() -> Option { } pub fn change_endianness(bytes: &[u8]) -> Vec { - let mut vec = Vec::new(); - for b in bytes.chunks(32) { - for byte in b.iter().rev() { - vec.push(*byte); - } + let mut vec = Vec::with_capacity(bytes.len()); + for chunk in bytes.chunks(32) { + vec.extend(chunk.iter().rev()); } vec } @@ -107,9 +104,5 @@ pub fn create_json_from_struct(json_struct: &T) -> String where T: Serialize, { - let json = json!(json_struct); - match serde_json::to_string_pretty(&json) { - Ok(json) => json, - Err(_) => panic!("Merkle tree data invalid"), - } + 
serde_json::to_string(json_struct).expect("JSON serialization failed for valid struct") } diff --git a/prover/client/src/proof.rs b/prover/client/src/proof.rs index f6b8f47ba3..c415a4d108 100644 --- a/prover/client/src/proof.rs +++ b/prover/client/src/proof.rs @@ -12,11 +12,6 @@ use solana_bn254::compression::prelude::{ convert_endianness, }; -pub struct ProofResult { - pub proof: ProofCompressed, - pub public_inputs: Vec<[u8; 32]>, -} - #[derive(Debug, Clone, Copy)] pub struct ProofCompressed { pub a: [u8; 32], @@ -24,6 +19,12 @@ pub struct ProofCompressed { pub c: [u8; 32], } +#[derive(Debug, Clone, Copy)] +pub struct ProofResult { + pub proof: ProofCompressed, + pub proof_duration_ms: u64, +} + impl From for CompressedProof { fn from(proof: ProofCompressed) -> Self { CompressedProof { diff --git a/prover/client/src/proof_client.rs b/prover/client/src/proof_client.rs index fe1d708baa..1d557407bd 100644 --- a/prover/client/src/proof_client.rs +++ b/prover/client/src/proof_client.rs @@ -10,6 +10,7 @@ use crate::{ errors::ProverClientError, proof::{ compress_proof, deserialize_gnark_proof_json, proof_from_json_struct, ProofCompressed, + ProofResult, }, proof_types::{ batch_address_append::{to_json, BatchAddressAppendInputs}, @@ -20,10 +21,13 @@ use crate::{ const MAX_RETRIES: u32 = 10; const BASE_RETRY_DELAY_SECS: u64 = 1; -const DEFAULT_POLLING_INTERVAL_SECS: u64 = 1; +const DEFAULT_POLLING_INTERVAL_MS: u64 = 100; const DEFAULT_MAX_WAIT_TIME_SECS: u64 = 600; const DEFAULT_LOCAL_SERVER: &str = "http://localhost:3001"; +const INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS: u64 = 200; +const INITIAL_POLL_DELAY_LARGE_CIRCUIT_MS: u64 = 200; + #[derive(Debug, Deserialize)] #[serde(untagged)] pub enum ProofResponse { @@ -33,6 +37,14 @@ pub enum ProofResponse { }, } +#[derive(Debug)] +pub enum SubmitProofResult { + /// Job was queued, poll with this ID + Queued(String), + /// Proof was returned immediately (sync response) + Immediate(ProofResult), +} + #[derive(Debug, Deserialize)] 
pub struct JobStatusResponse { pub status: String, @@ -52,6 +64,7 @@ pub struct ProofClient { polling_interval: Duration, max_wait_time: Duration, api_key: Option, + initial_poll_delay: Duration, } impl ProofClient { @@ -59,9 +72,10 @@ impl ProofClient { Self { client: Client::new(), server_address: DEFAULT_LOCAL_SERVER.to_string(), - polling_interval: Duration::from_secs(DEFAULT_POLLING_INTERVAL_SECS), + polling_interval: Duration::from_millis(DEFAULT_POLLING_INTERVAL_MS), max_wait_time: Duration::from_secs(DEFAULT_MAX_WAIT_TIME_SECS), api_key: None, + initial_poll_delay: Duration::from_millis(INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS), } } @@ -72,19 +86,90 @@ impl ProofClient { max_wait_time: Duration, api_key: Option, ) -> Self { + let initial_poll_delay = if api_key.is_some() { + Duration::from_millis(INITIAL_POLL_DELAY_LARGE_CIRCUIT_MS) + } else { + Duration::from_millis(INITIAL_POLL_DELAY_SMALL_CIRCUIT_MS) + }; + Self { client: Client::new(), server_address, polling_interval, max_wait_time, api_key, + initial_poll_delay, + } + } + + #[allow(unused)] + pub fn with_full_config( + server_address: String, + polling_interval: Duration, + max_wait_time: Duration, + api_key: Option, + initial_poll_delay: Duration, + ) -> Self { + Self { + client: Client::new(), + server_address, + polling_interval, + max_wait_time, + api_key, + initial_poll_delay, + } + } + + pub async fn submit_proof_async( + &self, + inputs_json: String, + circuit_type: &str, + ) -> Result { + debug!( + "Submitting async proof request for circuit type: {}", + circuit_type + ); + + let response = self.send_proof_request(&inputs_json).await?; + let status_code = response.status(); + let response_text = response.text().await.map_err(|e| { + ProverClientError::ProverServerError(format!("Failed to read response body: {}", e)) + })?; + + self.log_response(status_code, &response_text); + + match status_code { + reqwest::StatusCode::ACCEPTED => { + debug!("Received asynchronous job response"); + let 
job_response = self.parse_job_response(&response_text)?; + match job_response { + ProofResponse::Async { job_id, .. } => { + info!("Proof job queued with ID: {}", job_id); + Ok(SubmitProofResult::Queued(job_id)) + } + } + } + reqwest::StatusCode::OK => { + // Synchronous response - proof returned immediately + debug!("Received synchronous proof response"); + let proof = self.parse_proof_from_json(&response_text)?; + Ok(SubmitProofResult::Immediate(proof)) + } + _ => self.handle_error_response::(&response_text), } } + pub async fn poll_proof_completion( + &self, + job_id: String, + ) -> Result { + self.poll_for_result(&job_id, Duration::ZERO).await + } + pub async fn generate_proof( &self, inputs_json: String, - ) -> Result { + ) -> Result { let start_time = Instant::now(); let mut retries = 0; @@ -132,7 +217,7 @@ impl ProofClient { &self, inputs_json: &str, elapsed: Duration, - ) -> Result { + ) -> Result { let response = self.send_proof_request(inputs_json).await?; let status_code = response.status(); let response_text = response.text().await.map_err(|e| { @@ -173,7 +258,10 @@ impl ProofClient { fn log_response(&self, status_code: reqwest::StatusCode, response_text: &str) { if !status_code.is_success() { - error!("HTTP error: status={}, body={}", status_code, response_text); + error!( + "HTTP error: status={}, body={}, url={}", + status_code, response_text, self.server_address + ); } } @@ -182,7 +270,7 @@ impl ProofClient { status_code: reqwest::StatusCode, response_text: &str, start_elapsed: Duration, - ) -> Result { + ) -> Result { match status_code { reqwest::StatusCode::OK => self.parse_proof_from_json(response_text), reqwest::StatusCode::ACCEPTED => { @@ -204,7 +292,7 @@ impl ProofClient { &self, job_response: ProofResponse, start_elapsed: Duration, - ) -> Result { + ) -> Result { match job_response { ProofResponse::Async { job_id, .. 
} => { info!("Proof job queued with ID: {}", job_id); @@ -213,10 +301,7 @@ impl ProofClient { } } - fn handle_error_response( - &self, - response_text: &str, - ) -> Result { + fn handle_error_response(&self, response_text: &str) -> Result { if let Ok(error_response) = serde_json::from_str::(response_text) { error!( "Prover server error: {} - {}", @@ -244,6 +329,12 @@ impl ProofClient { return false; } + let is_constraint_error = + error_str.contains("constraint") || error_str.contains("is not satisfied"); + if is_constraint_error { + return false; + } + let is_retryable_error = error_str.contains("job_not_found") || error_str.contains("connection") || error_str.contains("timeout") @@ -265,15 +356,37 @@ impl ProofClient { &self, job_id: &str, start_elapsed: Duration, - ) -> Result { + ) -> Result { let poll_start_time = Instant::now(); let status_url = format!("{}/prove/status?job_id={}", self.server_address, job_id); info!("Starting to poll for job {} at URL: {}", job_id, status_url); + debug!( + "Waiting {:?} before first poll to allow prover to persist job {}", + self.initial_poll_delay, job_id + ); + sleep(self.initial_poll_delay).await; + let mut poll_count = 0; let mut transient_error_count = 0; + if poll_count > 1 { + let wasted_polls = poll_count - 1; + let suggested_delay_ms = self.initial_poll_delay.as_millis() as u64 + + (wasted_polls as u64 * self.polling_interval.as_millis() as u64); + + warn!( + "Job {} required {} polls (wasted {} polls before completion). 
\ + Consider increasing initial_poll_delay from {}ms to ~{}ms for better efficiency.", + job_id, + poll_count, + wasted_polls, + self.initial_poll_delay.as_millis(), + suggested_delay_ms + ); + } + loop { poll_count += 1; let poll_elapsed = poll_start_time.elapsed(); @@ -422,7 +535,7 @@ impl ProofClient { job_id: &str, elapsed: Duration, poll_count: u32, - ) -> Result, ProverClientError> { + ) -> Result, ProverClientError> { trace!( "Poll #{} for job {}: status='{}', message='{}'", poll_count, @@ -478,9 +591,10 @@ impl ProofClient { &self, result: Option, job_id: &str, - ) -> Result { + ) -> Result { match result { Some(result) => { + trace!("Job {} has result, parsing proof JSON", job_id); trace!("Job {} has result, parsing proof JSON", job_id); let proof_json = serde_json::to_string(&result).map_err(|e| { error!("Failed to serialize result for job {}: {}", job_id, e); @@ -506,25 +620,58 @@ impl ProofClient { error_str.contains("503") || error_str.contains("502") || error_str.contains("500") } - fn parse_proof_from_json(&self, json_str: &str) -> Result { - let proof_json = deserialize_gnark_proof_json(json_str).map_err(|e| { + fn parse_proof_from_json(&self, json_str: &str) -> Result { + // Try parsing as ProofWithTiming format (new format with timing) + #[derive(Deserialize)] + struct ProofWithTimingJson { + proof: serde_json::Value, + proof_duration_ms: u64, + } + + let (proof_json_value, proof_duration_ms) = if let Ok(proof_with_timing) = + serde_json::from_str::(json_str) + { + (proof_with_timing.proof, proof_with_timing.proof_duration_ms) + } else { + // Fall back to plain proof format (old format without timing) + let proof_value: serde_json::Value = serde_json::from_str(json_str).map_err(|e| { + ProverClientError::ProverServerError(format!("Failed to parse proof JSON: {}", e)) + })?; + (proof_value, 0) + }; + + // Check if proof is null - this indicates the prover failed to generate a proof + if proof_json_value.is_null() { + return 
Err(ProverClientError::ProverServerError( + "Prover returned null proof - proof generation failed on server side".to_string(), + )); + } + + let proof_json_str = serde_json::to_string(&proof_json_value).map_err(|e| { + ProverClientError::ProverServerError(format!("Failed to serialize proof JSON: {}", e)) + })?; + + let proof_json = deserialize_gnark_proof_json(&proof_json_str).map_err(|e| { ProverClientError::ProverServerError(format!("Failed to deserialize proof JSON: {}", e)) })?; let (proof_a, proof_b, proof_c) = proof_from_json_struct(proof_json); let (proof_a, proof_b, proof_c) = compress_proof(&proof_a, &proof_b, &proof_c); - Ok(ProofCompressed { - a: proof_a, - b: proof_b, - c: proof_c, + Ok(ProofResult { + proof: ProofCompressed { + a: proof_a, + b: proof_b, + c: proof_c, + }, + proof_duration_ms, }) } pub async fn generate_batch_address_append_proof( &self, inputs: BatchAddressAppendInputs, - ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> { + ) -> Result<(ProofResult, [u8; 32]), ProverClientError> { let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>(&inputs.new_root)?; let inputs_json = to_json(&inputs); let proof = self.generate_proof(inputs_json).await?; @@ -534,7 +681,7 @@ impl ProofClient { pub async fn generate_batch_append_proof( &self, circuit_inputs: BatchAppendsCircuitInputs, - ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> { + ) -> Result<(ProofResult, [u8; 32]), ProverClientError> { let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>( &circuit_inputs.new_root.to_biguint().unwrap(), )?; @@ -546,7 +693,7 @@ impl ProofClient { pub async fn generate_batch_update_proof( &self, circuit_inputs: BatchUpdateCircuitInputs, - ) -> Result<(ProofCompressed, [u8; 32]), ProverClientError> { + ) -> Result<(ProofResult, [u8; 32]), ProverClientError> { let new_root = light_hasher::bigint::bigint_to_be_bytes_array::<32>( &circuit_inputs.new_root.to_biguint().unwrap(), )?; diff --git 
a/prover/client/src/proof_types/batch_address_append/json.rs b/prover/client/src/proof_types/batch_address_append/json.rs index fcb245a874..cd31a326e8 100644 --- a/prover/client/src/proof_types/batch_address_append/json.rs +++ b/prover/client/src/proof_types/batch_address_append/json.rs @@ -35,6 +35,12 @@ pub struct BatchAddressAppendInputsJson { pub start_index: usize, #[serde(rename = "treeHeight")] pub tree_height: usize, + /// Tree pubkey for fair queuing - used to prevent starvation when multiple trees have proofs pending + #[serde(rename = "treeId", skip_serializing_if = "Option::is_none")] + pub tree_id: Option, + /// Batch index for ordering - ensures batches are processed in sequence within a tree + #[serde(rename = "batchIndex", skip_serializing_if = "Option::is_none")] + pub batch_index: Option, } impl BatchAddressAppendInputsJson { @@ -78,9 +84,23 @@ impl BatchAddressAppendInputsJson { public_input_hash: big_uint_to_string(&inputs.public_input_hash), start_index: inputs.start_index, tree_height: inputs.tree_height, + tree_id: None, + batch_index: None, } } + /// Set the tree ID for fair queuing across multiple trees + pub fn with_tree_id(mut self, tree_id: String) -> Self { + self.tree_id = Some(tree_id); + self + } + + /// Set the batch index for ordering within a tree + pub fn with_batch_index(mut self, batch_index: u64) -> Self { + self.batch_index = Some(batch_index); + self + } + #[allow(clippy::inherent_to_string)] pub fn to_string(&self) -> String { create_json_from_struct(&self) diff --git a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs index baeba1bb03..f80e8d49e4 100644 --- a/prover/client/src/proof_types/batch_address_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_address_append/proof_inputs.rs @@ -1,5 +1,9 @@ +use std::collections::HashMap; + use light_hasher::{ - bigint::bigint_to_be_bytes_array, 
hash_chain::create_hash_chain_from_array, Poseidon, + bigint::bigint_to_be_bytes_array, + hash_chain::{create_hash_chain_from_array, create_hash_chain_from_slice}, + Poseidon, }; use light_indexed_array::{array::IndexedElement, changelog::RawIndexedElement}; use light_sparse_merkle_tree::{ @@ -11,6 +15,78 @@ use num_bigint::BigUint; use crate::{errors::ProverClientError, helpers::compute_root_from_merkle_proof}; +#[derive(Default)] +struct ProofCache { + cache: HashMap<(usize, usize), [u8; 32]>, +} + +impl ProofCache { + fn add_entry(&mut self, entry: &ChangelogEntry) { + let index = entry.index(); + for level in 0..HEIGHT { + if let Some(hash) = entry.path[level] { + let node_index = index >> level; + self.cache.insert((level, node_index), hash); + } + } + } + + fn get_sibling_hash(&self, level: usize, sibling_node_index: usize) -> Option<[u8; 32]> { + self.cache.get(&(level, sibling_node_index)).copied() + } +} + +struct ChangelogProofPatcher { + cache: ProofCache, + hits: usize, + misses: usize, + overwrites: usize, +} + +impl ChangelogProofPatcher { + fn new(changelog: &[ChangelogEntry]) -> Self { + let mut cache = ProofCache::default(); + for entry in changelog.iter() { + cache.add_entry::(entry); + } + Self { + cache, + hits: 0, + misses: 0, + overwrites: 0, + } + } + + fn update_proof( + &mut self, + leaf_index: usize, + proof: &mut [[u8; 32]; HEIGHT], + ) { + for (level, proof_element) in proof.iter_mut().enumerate() { + let my_node_index = leaf_index >> level; + let sibling_node_index = my_node_index ^ 1; + if let Some(hash) = self.cache.get_sibling_hash(level, sibling_node_index) { + self.hits += 1; + if *proof_element != hash { + self.overwrites += 1; + } + *proof_element = hash; + } else { + self.misses += 1; + } + } + } + + fn push_changelog_entry( + &mut self, + changelog: &mut Vec>, + entry: ChangelogEntry, + ) { + self.cache.add_entry::(&entry); + changelog.push(entry); + } +} + #[derive(Debug, Clone)] pub struct BatchAddressAppendInputs { pub 
batch_size: usize, @@ -29,6 +105,84 @@ pub struct BatchAddressAppendInputs { pub tree_height: usize, } +impl BatchAddressAppendInputs { + #[allow(clippy::too_many_arguments)] + pub fn new( + batch_size: usize, + leaves_hashchain: [u8; 32], + low_element_values: &[[u8; 32]], + low_element_indices: &[u64], + low_element_next_indices: &[u64], + low_element_next_values: &[[u8; 32]], + low_element_proofs: Vec>, + new_element_values: &[[u8; 32]], + new_element_proofs: Vec>, + new_root: [u8; 32], + old_root: [u8; 32], + start_index: usize, + ) -> Result { + let hash_chain_inputs = [ + old_root, + new_root, + leaves_hashchain, + bigint_to_be_bytes_array::<32>(&start_index.into())?, + ]; + let public_input_hash = create_hash_chain_from_array(hash_chain_inputs)?; + + let low_element_proofs_bigint: Vec> = low_element_proofs + .into_iter() + .map(|proof| { + proof + .into_iter() + .map(|p| BigUint::from_bytes_be(&p)) + .collect() + }) + .collect(); + + let new_element_proofs_bigint: Vec> = new_element_proofs + .into_iter() + .map(|proof| { + proof + .into_iter() + .map(|p| BigUint::from_bytes_be(&p)) + .collect() + }) + .collect(); + + Ok(Self { + batch_size, + hashchain_hash: BigUint::from_bytes_be(&leaves_hashchain), + low_element_values: low_element_values + .iter() + .map(|v| BigUint::from_bytes_be(v)) + .collect(), + low_element_indices: low_element_indices + .iter() + .map(|&i| BigUint::from(i)) + .collect(), + low_element_next_indices: low_element_next_indices + .iter() + .map(|&i| BigUint::from(i)) + .collect(), + low_element_next_values: low_element_next_values + .iter() + .map(|v| BigUint::from_bytes_be(v)) + .collect(), + low_element_proofs: low_element_proofs_bigint, + new_element_values: new_element_values + .iter() + .map(|v| BigUint::from_bytes_be(v)) + .collect(), + new_element_proofs: new_element_proofs_bigint, + new_root: BigUint::from_bytes_be(&new_root), + old_root: BigUint::from_bytes_be(&old_root), + public_input_hash: 
BigUint::from_bytes_be(&public_input_hash), + start_index, + tree_height: HEIGHT, + }) + } +} + #[allow(clippy::too_many_arguments)] pub fn get_batch_address_append_circuit_inputs( next_index: usize, @@ -45,12 +199,36 @@ pub fn get_batch_address_append_circuit_inputs( changelog: &mut Vec>, indexed_changelog: &mut Vec>, ) -> Result { - // 1. input all elements of a batch. - // 2. iterate over elements 0..end_index - // 3. only use elements start_index..end_index in the circuit (we need to - // iterate over elements prior to start index to create changelog entries to - // patch subsequent element proofs. The indexer won't be caught up yet.) let new_element_values = new_element_values[0..zkp_batch_size].to_vec(); + + let computed_hashchain = create_hash_chain_from_slice(&new_element_values).map_err(|e| { + ProverClientError::GenericError(format!("Failed to compute hashchain: {}", e)) + })?; + if computed_hashchain != leaves_hashchain { + tracing::error!( + "hashchain mismatch: computed {:?} != indexer {:?} (batch_size={}, next_index={})", + &computed_hashchain[..8], + &leaves_hashchain[..8], + zkp_batch_size, + next_index + ); + for (i, addr) in new_element_values.iter().take(3).enumerate() { + tracing::error!(" address[{}] = {:?}[..8]", i, &addr[..8]); + } + return Err(ProverClientError::HashchainMismatch { + computed: computed_hashchain, + expected: leaves_hashchain, + batch_size: zkp_batch_size, + next_index, + }); + } + tracing::debug!( + "Hashchain validated OK: {:?}[..4] (batch_size={}, next_index={})", + &computed_hashchain[..4], + zkp_batch_size, + next_index + ); + let mut new_root = [0u8; 32]; let mut low_element_circuit_merkle_proofs = vec![]; let mut new_element_circuit_merkle_proofs = vec![]; @@ -60,6 +238,11 @@ pub fn get_batch_address_append_circuit_inputs( let mut patched_low_element_values: Vec<[u8; 32]> = Vec::new(); let mut patched_low_element_indices: Vec = Vec::new(); + let mut patcher = ChangelogProofPatcher::new::(changelog); + + let 
is_first_batch = indexed_changelog.is_empty(); + let mut expected_root_for_low = current_root; + for i in 0..new_element_values.len() { let mut changelog_index = 0; @@ -87,13 +270,17 @@ pub fn get_batch_address_append_circuit_inputs( &mut low_element_next_value, &mut low_element_proof, ) - .unwrap(); + .map_err(|e| { + ProverClientError::ProofPatchFailed(format!( + "failed to patch indexed changelogs: {}", + e + )) + })?; patched_low_element_next_values - .push(bigint_to_be_bytes_array::<32>(&low_element_next_value).unwrap()); + .push(bigint_to_be_bytes_array::<32>(&low_element_next_value)?); patched_low_element_next_indices.push(low_element.next_index()); patched_low_element_indices.push(low_element.index); - patched_low_element_values - .push(bigint_to_be_bytes_array::<32>(&low_element.value).unwrap()); + patched_low_element_values.push(bigint_to_be_bytes_array::<32>(&low_element.value)?); let new_low_element: IndexedElement = IndexedElement { index: low_element.index, @@ -101,38 +288,101 @@ pub fn get_batch_address_append_circuit_inputs( next_index: new_element.index, }; let new_low_element_raw = RawIndexedElement { - value: bigint_to_be_bytes_array::<32>(&new_low_element.value).unwrap(), + value: bigint_to_be_bytes_array::<32>(&new_low_element.value)?, next_index: new_low_element.next_index, - next_value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), + next_value: bigint_to_be_bytes_array::<32>(&new_element.value)?, index: new_low_element.index, }; - { - for change_log_entry in changelog.iter().skip(changelog_index) { - change_log_entry - .update_proof(low_element.index(), &mut low_element_proof) - .unwrap(); + let intermediate_root = { + let mut low_element_proof_arr: [[u8; 32]; HEIGHT] = low_element_proof + .clone() + .try_into() + .map_err(|v: Vec<[u8; 32]>| { + ProverClientError::ProofPatchFailed(format!( + "low element proof length mismatch: expected {}, got {}", + HEIGHT, + v.len() + )) + })?; + patcher.update_proof::(low_element.index(), 
&mut low_element_proof_arr); + let merkle_proof = low_element_proof_arr; + + if is_first_batch { + let old_low_leaf_hash = low_element + .hash::(&low_element_next_value) + .map_err(|e| { + ProverClientError::GenericError(format!( + "Failed to hash old low element: {}", + e + )) + })?; + let (computed_root, _) = compute_root_from_merkle_proof::( + old_low_leaf_hash, + &merkle_proof, + low_element.index as u32, + ); + if computed_root != expected_root_for_low { + let low_value_bytes = bigint_to_be_bytes_array::<32>(&low_element.value) + .map_err(|e| { + ProverClientError::GenericError(format!( + "Failed to serialize low element value: {}", + e + )) + })?; + let low_next_value_bytes = + bigint_to_be_bytes_array::<32>(&low_element_next_value).map_err(|e| { + ProverClientError::GenericError(format!( + "Failed to serialize low element next value: {}", + e + )) + })?; + return Err(ProverClientError::GenericError(format!( + "element {}: low proof mismatch (computed {:?}[..4] != expected {:?}[..4], low_idx={}, low_value={:?}[..4], low_next={:?}[..4])", + i, + &computed_root[..4], + &expected_root_for_low[..4], + low_element.index, + &low_value_bytes[..4], + &low_next_value_bytes[..4], + ))); + } } - let merkle_proof = low_element_proof.clone().try_into().unwrap(); + let new_low_leaf_hash = new_low_element .hash::(&new_element.value) - .unwrap(); - let (_updated_root, changelog_entry) = compute_root_from_merkle_proof::( - new_low_leaf_hash, - &merkle_proof, - new_low_element.index as u32, - ); - changelog.push(changelog_entry); + .map_err(|e| { + ProverClientError::GenericError(format!( + "Failed to hash new low element: {}", + e + )) + })?; + let (low_update_intermediate_root, changelog_entry) = + compute_root_from_merkle_proof::( + new_low_leaf_hash, + &merkle_proof, + new_low_element.index as u32, + ); + + patcher.push_changelog_entry::(changelog, changelog_entry); low_element_circuit_merkle_proofs.push( merkle_proof .iter() .map(|hash| BigUint::from_bytes_be(hash)) 
.collect(), ); - } + + low_update_intermediate_root + }; let low_element_changelog_entry = IndexedChangelogEntry { element: new_low_element_raw, - proof: low_element_proof.as_slice()[..HEIGHT].try_into().unwrap(), + proof: low_element_proof.as_slice()[..HEIGHT] + .try_into() + .map_err(|_| { + ProverClientError::ProofPatchFailed( + "low_element_proof slice conversion failed".to_string(), + ) + })?, changelog_index: indexed_changelog.len(), //change_log_index, }; @@ -142,25 +392,85 @@ pub fn get_batch_address_append_circuit_inputs( let new_element_next_value = low_element_next_value; let new_element_leaf_hash = new_element .hash::(&new_element_next_value) - .unwrap(); + .map_err(|e| { + ProverClientError::GenericError(format!("Failed to hash new element: {}", e)) + })?; + + let sparse_root_before = sparse_merkle_tree.root(); + let sparse_next_idx_before = sparse_merkle_tree.get_next_index(); + let mut merkle_proof_array = sparse_merkle_tree.append(new_element_leaf_hash); let current_index = next_index + i; - for change_log_entry in changelog.iter() { - change_log_entry - .update_proof(current_index, &mut merkle_proof_array) - .unwrap(); - } + patcher.update_proof::(current_index, &mut merkle_proof_array); let (updated_root, changelog_entry) = compute_root_from_merkle_proof( new_element_leaf_hash, &merkle_proof_array, current_index as u32, ); + + if i == 0 && changelog.len() == 1 { + if sparse_next_idx_before != current_index { + return Err(ProverClientError::GenericError(format!( + "sparse index mismatch: sparse tree next_index={} but expected current_index={}", + sparse_next_idx_before, current_index + ))); + } + + if sparse_root_before != current_root { + return Err(ProverClientError::GenericError(format!( + "sparse root mismatch: sparse tree root {:?}[..4] != current_root {:?}[..4] \ + (next_index={}). 
The subtrees from indexer may be stale.", + &sparse_root_before[..4], + ¤t_root[..4], + next_index + ))); + } + } + + if is_first_batch { + let zero_hash = [0u8; 32]; + let (root_with_zero, _) = compute_root_from_merkle_proof::( + zero_hash, + &merkle_proof_array, + current_index as u32, + ); + if root_with_zero != intermediate_root { + tracing::error!( + "ELEMENT {} NEW_PROOF MISMATCH: proof + ZERO = {:?}[..4] but expected \ + intermediate_root = {:?}[..4] (index={}, low_idx={})", + i, + &root_with_zero[..4], + &intermediate_root[..4], + current_index, + low_element.index + ); + return Err(ProverClientError::GenericError(format!( + "ELEMENT {} NEW_PROOF MISMATCH: proof + ZERO = {:?}[..4] but expected \ + intermediate_root = {:?}[..4] (index={}, low_idx={}). Patched proof is incorrect.", + i, + &root_with_zero[..4], + &intermediate_root[..4], + current_index, + low_element.index + ))); + } + if i == 0 { + tracing::info!( + "VALIDATION_PASS: element 0 new_element proof OK \ + (intermediate_root {:?}[..4] -> updated_root {:?}[..4])", + &intermediate_root[..4], + &updated_root[..4] + ); + } + expected_root_for_low = updated_root; + } + new_root = updated_root; - changelog.push(changelog_entry); + patcher.push_changelog_entry::(changelog, changelog_entry); new_element_circuit_merkle_proofs.push( merkle_proof_array .iter() @@ -169,9 +479,9 @@ pub fn get_batch_address_append_circuit_inputs( ); let new_element_raw = RawIndexedElement { - value: bigint_to_be_bytes_array::<32>(&new_element.value).unwrap(), + value: bigint_to_be_bytes_array::<32>(&new_element.value)?, next_index: new_element.next_index, - next_value: bigint_to_be_bytes_array::<32>(&new_element_next_value).unwrap(), + next_value: bigint_to_be_bytes_array::<32>(&new_element_next_value)?, index: new_element.index, }; @@ -188,7 +498,7 @@ pub fn get_batch_address_append_circuit_inputs( current_root, new_root, leaves_hashchain, - bigint_to_be_bytes_array::<32>(&next_index.into()).unwrap(), + 
bigint_to_be_bytes_array::<32>(&next_index.into())?, ]; for (idx, ((low_value, new_value), high_value)) in patched_low_element_values @@ -211,6 +521,22 @@ pub fn get_batch_address_append_circuit_inputs( let public_input_hash = create_hash_chain_from_array(hash_chain_inputs)?; + tracing::debug!( + "Address proof patcher stats: hits={}, misses={}, overwrites={}, changelog_len={}, indexed_changelog_len={}", + patcher.hits, + patcher.misses, + patcher.overwrites, + changelog.len(), + indexed_changelog.len() + ); + if patcher.hits == 0 && !changelog.is_empty() { + tracing::warn!( + "Address proof patcher had 0 cache hits despite non-empty changelog (changelog_len={}, indexed_changelog_len={})", + changelog.len(), + indexed_changelog.len() + ); + } + Ok(BatchAddressAppendInputs { batch_size: patched_low_element_values.len(), hashchain_hash: BigUint::from_bytes_be(&leaves_hashchain), diff --git a/prover/client/src/proof_types/batch_append/json.rs b/prover/client/src/proof_types/batch_append/json.rs index 917d496220..7f68d0899c 100644 --- a/prover/client/src/proof_types/batch_append/json.rs +++ b/prover/client/src/proof_types/batch_append/json.rs @@ -29,6 +29,12 @@ pub struct BatchAppendInputsJson { height: u32, #[serde(rename = "batchSize")] batch_size: u32, + /// Tree pubkey for fair queuing - used to prevent starvation when multiple trees have proofs pending + #[serde(rename = "treeId", skip_serializing_if = "Option::is_none")] + tree_id: Option, + /// Batch index for ordering - ensures batches are processed in sequence within a tree + #[serde(rename = "batchIndex", skip_serializing_if = "Option::is_none")] + batch_index: Option, } impl BatchAppendInputsJson { @@ -49,9 +55,23 @@ impl BatchAppendInputsJson { .collect(), height: inputs.height, batch_size: inputs.batch_size, + tree_id: None, + batch_index: None, } } + /// Set the tree ID for fair queuing across multiple trees + pub fn with_tree_id(mut self, tree_id: String) -> Self { + self.tree_id = Some(tree_id); + self 
+ } + + /// Set the batch index for ordering within a tree + pub fn with_batch_index(mut self, batch_index: u64) -> Self { + self.batch_index = Some(batch_index); + self + } + #[allow(clippy::inherent_to_string)] pub fn to_string(&self) -> String { create_json_from_struct(&self) diff --git a/prover/client/src/proof_types/batch_append/proof_inputs.rs b/prover/client/src/proof_types/batch_append/proof_inputs.rs index bec593ec1b..ef0327ac1d 100644 --- a/prover/client/src/proof_types/batch_append/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_append/proof_inputs.rs @@ -32,7 +32,7 @@ impl BatchAppendsCircuitInputs { pub fn new( tree_result: BatchTreeUpdateResult, start_index: u32, - leaves: Vec<[u8; 32]>, + leaves: &[[u8; 32]], leaves_hashchain: [u8; 32], batch_size: u32, ) -> Result { diff --git a/prover/client/src/proof_types/batch_update/json.rs b/prover/client/src/proof_types/batch_update/json.rs index 51dfd72d20..d03a93dc47 100644 --- a/prover/client/src/proof_types/batch_update/json.rs +++ b/prover/client/src/proof_types/batch_update/json.rs @@ -31,6 +31,12 @@ pub struct BatchUpdateProofInputsJson { pub batch_size: u32, #[serde(rename(serialize = "txHashes"))] pub tx_hashes: Vec, + /// Tree pubkey for fair queuing - used to prevent starvation when multiple trees have proofs pending + #[serde(rename = "treeId", skip_serializing_if = "Option::is_none")] + pub tree_id: Option, + /// Batch index for ordering - ensures batches are processed in sequence within a tree + #[serde(rename = "batchIndex", skip_serializing_if = "Option::is_none")] + pub batch_index: Option, } #[derive(Serialize, Debug)] @@ -79,9 +85,23 @@ impl BatchUpdateProofInputsJson { height, batch_size, tx_hashes, + tree_id: None, + batch_index: None, } } + /// Set the tree ID for fair queuing across multiple trees + pub fn with_tree_id(mut self, tree_id: String) -> Self { + self.tree_id = Some(tree_id); + self + } + + /// Set the batch index for ordering within a tree + pub fn 
with_batch_index(mut self, batch_index: u64) -> Self { + self.batch_index = Some(batch_index); + self + } + #[allow(clippy::inherent_to_string)] pub fn to_string(&self) -> String { create_json_from_struct(&self) diff --git a/prover/client/src/proof_types/batch_update/proof_inputs.rs b/prover/client/src/proof_types/batch_update/proof_inputs.rs index 7c3f1be65e..e63f50ff23 100644 --- a/prover/client/src/proof_types/batch_update/proof_inputs.rs +++ b/prover/client/src/proof_types/batch_update/proof_inputs.rs @@ -7,9 +7,6 @@ use crate::{ helpers::{bigint_to_u8_32, compute_root_from_merkle_proof}, }; -/// Result of batch tree updates, containing proofs and root transitions. -/// This mirrors `forester_utils::staging_tree::BatchUpdateResult` but is defined -/// here to avoid a dependency cycle. #[derive(Clone, Debug)] pub struct BatchTreeUpdateResult { pub old_leaves: Vec<[u8; 32]>, @@ -40,10 +37,10 @@ impl BatchUpdateCircuitInputs { pub fn new( tree_result: BatchTreeUpdateResult, - tx_hashes: Vec<[u8; 32]>, - leaves: Vec<[u8; 32]>, + tx_hashes: &[[u8; 32]], + leaves: &[[u8; 32]], leaves_hashchain: [u8; 32], - path_indices: Vec, + path_indices: &[u32], batch_size: u32, ) -> Result { let batch_size_usize = batch_size as usize; @@ -104,7 +101,7 @@ impl BatchUpdateCircuitInputs { .map(|leaf| BigInt::from_bytes_be(Sign::Plus, leaf)) .collect(), merkle_proofs: circuit_merkle_proofs, - path_indices, + path_indices: path_indices.to_vec(), height: HEIGHT as u32, batch_size, }) diff --git a/prover/server/go.mod b/prover/server/go.mod index f2be71318d..5176cb6e68 100644 --- a/prover/server/go.mod +++ b/prover/server/go.mod @@ -1,6 +1,6 @@ module light/light-prover -go 1.25.4 +go 1.25.5 require ( github.com/consensys/gnark v0.14.0 diff --git a/prover/server/main.go b/prover/server/main.go index 1b7122ffb7..2ddb32db7a 100644 --- a/prover/server/main.go +++ b/prover/server/main.go @@ -9,7 +9,7 @@ import ( "light/light-prover/logging" "light/light-prover/prover/common" v1 
"light/light-prover/prover/v1" - "light/light-prover/prover/v2" + v2 "light/light-prover/prover/v2" "light/light-prover/server" "os" "os/signal" @@ -600,7 +600,7 @@ func runCli() { &cli.BoolFlag{ Name: "auto-download", Usage: "Automatically download missing key files", - Value: true, + Value: false, }, &cli.StringFlag{ Name: "download-url", @@ -640,33 +640,37 @@ func runCli() { Str("keys_dir", keysDirPath). Msg("Initializing lazy key manager") - if preloadKeys == "all" { - logging.Logger().Info().Msg("Preloading all keys") - if err := keyManager.PreloadAll(); err != nil { - return fmt.Errorf("failed to preload all keys: %w", err) - } - } else if preloadKeys != "none" { - preloadRunMode, err := parseRunMode(preloadKeys) - if err != nil { - return fmt.Errorf("invalid --preload-keys value: %s (must be none, all, or a valid run mode: rpc, forester, forester-test, full, full-test, local-rpc)", preloadKeys) - } - logging.Logger().Info().Str("run_mode", string(preloadRunMode)).Msg("Preloading keys for run mode") - if err := keyManager.PreloadForRunMode(preloadRunMode); err != nil { - return fmt.Errorf("failed to preload keys for run mode: %w", err) + // Preload keys asynchronously to allow health checks to pass during startup + preloadAsync := func() { + if preloadKeys == "all" { + logging.Logger().Info().Msg("Preloading all keys (async)") + if err := keyManager.PreloadAll(); err != nil { + logging.Logger().Error().Err(err).Msg("Failed to preload all keys") + } + } else if preloadKeys != "none" { + preloadRunMode, err := parseRunMode(preloadKeys) + if err != nil { + logging.Logger().Error().Err(err).Str("value", preloadKeys).Msg("Invalid --preload-keys value") + } else { + logging.Logger().Info().Str("run_mode", string(preloadRunMode)).Msg("Preloading keys for run mode (async)") + if err := keyManager.PreloadForRunMode(preloadRunMode); err != nil { + logging.Logger().Error().Err(err).Msg("Failed to preload keys for run mode") + } + } } - } - if len(preloadCircuits) > 0 { 
- logging.Logger().Info().Strs("circuits", preloadCircuits).Msg("Preloading specific circuits") - if err := keyManager.PreloadCircuits(preloadCircuits); err != nil { - return fmt.Errorf("failed to preload circuits: %w", err) + if len(preloadCircuits) > 0 { + logging.Logger().Info().Strs("circuits", preloadCircuits).Msg("Preloading specific circuits (async)") + if err := keyManager.PreloadCircuits(preloadCircuits); err != nil { + logging.Logger().Error().Err(err).Msg("Failed to preload circuits") + } } - } - stats := keyManager.GetStats() - logging.Logger().Info(). - Interface("stats", stats). - Msg("Key manager initialized") + stats := keyManager.GetStats() + logging.Logger().Info(). + Interface("stats", stats). + Msg("Key preloading completed") + } redisURL := context.String("redis-url") if redisURL == "" { @@ -715,20 +719,38 @@ func runCli() { logging.Logger().Info().Msg("Starting queue workers") - updateWorker := server.NewUpdateQueueWorker(redisQueue, keyManager) - workers = append(workers, updateWorker) - go updateWorker.Start() + enabledCircuits := context.StringSlice("circuit") + enabledCircuitsMap := make(map[string]bool) + for _, c := range enabledCircuits { + enabledCircuitsMap[c] = true + } + + startAll := len(enabledCircuits) == 0 + var workersStarted []string - appendWorker := server.NewAppendQueueWorker(redisQueue, keyManager) - workers = append(workers, appendWorker) - go appendWorker.Start() + if startAll || enabledCircuitsMap["update"] || enabledCircuitsMap["update-test"] { + updateWorker := server.NewUpdateQueueWorker(redisQueue, keyManager) + workers = append(workers, updateWorker) + go updateWorker.Start() + workersStarted = append(workersStarted, "update") + } - addressAppendWorker := server.NewAddressAppendQueueWorker(redisQueue, keyManager) - workers = append(workers, addressAppendWorker) - go addressAppendWorker.Start() + if startAll || enabledCircuitsMap["append"] || enabledCircuitsMap["append-test"] { + appendWorker := 
server.NewAppendQueueWorker(redisQueue, keyManager) + workers = append(workers, appendWorker) + go appendWorker.Start() + workersStarted = append(workersStarted, "append") + } + + if startAll || enabledCircuitsMap["address-append"] || enabledCircuitsMap["address-append-test"] { + addressAppendWorker := server.NewAddressAppendQueueWorker(redisQueue, keyManager) + workers = append(workers, addressAppendWorker) + go addressAppendWorker.Start() + workersStarted = append(workersStarted, "address-append") + } logging.Logger().Info(). - Strs("workers_started", []string{"update", "append", "address-append"}). + Strs("workers_started", workersStarted). Msg("Queue workers started") } @@ -757,6 +779,8 @@ func runCli() { return fmt.Errorf("at least one of server or queue mode must be enabled") } + go preloadAsync() + sigint := make(chan os.Signal, 1) signal.Notify(sigint, os.Interrupt) <-sigint diff --git a/prover/server/prover/common/circuit_utils.go b/prover/server/prover/common/circuit_utils.go index 4f15a352f8..d994181fa3 100644 --- a/prover/server/prover/common/circuit_utils.go +++ b/prover/server/prover/common/circuit_utils.go @@ -15,6 +15,12 @@ type Proof struct { Proof groth16.Proof } +// ProofWithTiming wraps a proof with timing information for metrics +type ProofWithTiming struct { + Proof *Proof `json:"proof"` + ProofDurationMs int64 `json:"proof_duration_ms"` +} + type MerkleProofSystem struct { InclusionTreeHeight uint32 InclusionNumberOfCompressedAccounts uint32 diff --git a/prover/server/prover/common/proof_request_meta.go b/prover/server/prover/common/proof_request_meta.go index 2602bc086c..6b2dc1c65d 100644 --- a/prover/server/prover/common/proof_request_meta.go +++ b/prover/server/prover/common/proof_request_meta.go @@ -14,6 +14,11 @@ type ProofRequestMeta struct { TreeHeight uint32 NumInputs uint32 NumAddresses uint32 + // TreeID is the merkle tree pubkey - used for fair queuing across trees + TreeID string + // BatchIndex is the batch sequence number within 
a tree - used to process batches in order + // Lower batch indices should be processed first to enable sequential transaction submission + BatchIndex int64 } // ParseProofRequestMeta parses a JSON input and extracts CircuitType, tree heights, and additional metrics. @@ -73,6 +78,19 @@ func ParseProofRequestMeta(data []byte) (ProofRequestMeta, error) { numAddresses = len(nonInclusionInputs) } + // Extract TreeID for fair queuing + treeID := "" + if id, ok := rawInput["treeId"].(string); ok { + treeID = id + } + + // Extract BatchIndex for ordering proofs within a tree + // Default to -1 to indicate no batch index (legacy requests) + batchIndex := int64(-1) + if idx, ok := rawInput["batchIndex"].(float64); ok { + batchIndex = int64(idx) + } + return ProofRequestMeta{ Version: version, CircuitType: CircuitType(circuitType), @@ -80,5 +98,7 @@ func ParseProofRequestMeta(data []byte) (ProofRequestMeta, error) { AddressTreeHeight: addressTreeHeight, NumInputs: uint32(numInputs), NumAddresses: uint32(numAddresses), + TreeID: treeID, + BatchIndex: batchIndex, }, nil } diff --git a/prover/server/prover/v1/combined_proving_system.go b/prover/server/prover/v1/combined_proving_system.go index 492673c505..ddeb394ecf 100644 --- a/prover/server/prover/v1/combined_proving_system.go +++ b/prover/server/prover/v1/combined_proving_system.go @@ -103,10 +103,6 @@ func ProveCombined(ps *common.MerkleProofSystem, params *CombinedParameters) (*c circuit := InitializeCombinedCircuit(ps.InclusionTreeHeight, ps.InclusionNumberOfCompressedAccounts, ps.NonInclusionTreeHeight, ps.NonInclusionNumberOfCompressedAccounts) for i := 0; i < int(ps.InclusionNumberOfCompressedAccounts); i++ { - logging.Logger().Debug().Msgf("v1.ProveCombined: Inclusion[%d] Root=%v Leaf=%v PathIndex=%v", - i, params.InclusionParameters.Inputs[i].Root, - params.InclusionParameters.Inputs[i].Leaf, - params.InclusionParameters.Inputs[i].PathIndex) circuit.Inclusion.Roots[i] = params.InclusionParameters.Inputs[i].Root 
circuit.Inclusion.Leaves[i] = params.InclusionParameters.Inputs[i].Leaf circuit.Inclusion.InPathIndices[i] = params.InclusionParameters.Inputs[i].PathIndex @@ -117,14 +113,6 @@ func ProveCombined(ps *common.MerkleProofSystem, params *CombinedParameters) (*c } for i := 0; i < int(ps.NonInclusionNumberOfCompressedAccounts); i++ { - logging.Logger().Debug().Msgf("v1.ProveCombined: NonInclusion[%d] Root=%v Value=%v", - i, params.NonInclusionParameters.Inputs[i].Root, - params.NonInclusionParameters.Inputs[i].Value) - logging.Logger().Debug().Msgf("v1.ProveCombined: NonInclusion[%d] LeafLowerRangeValue=%v LeafHigherRangeValue=%v PathIndex=%v", - i, params.NonInclusionParameters.Inputs[i].LeafLowerRangeValue, - params.NonInclusionParameters.Inputs[i].LeafHigherRangeValue, - params.NonInclusionParameters.Inputs[i].PathIndex) - circuit.NonInclusion.Roots[i] = params.NonInclusionParameters.Inputs[i].Root circuit.NonInclusion.Values[i] = params.NonInclusionParameters.Inputs[i].Value circuit.NonInclusion.LeafLowerRangeValues[i] = params.NonInclusionParameters.Inputs[i].LeafLowerRangeValue diff --git a/prover/server/prover/v1/non_inclusion_proving_system.go b/prover/server/prover/v1/non_inclusion_proving_system.go index 8064acb6df..b31d800e93 100644 --- a/prover/server/prover/v1/non_inclusion_proving_system.go +++ b/prover/server/prover/v1/non_inclusion_proving_system.go @@ -64,7 +64,6 @@ func ProveNonInclusion(ps *common.MerkleProofSystem, params *NonInclusionParamet inPathIndices := make([]frontend.Variable, ps.NonInclusionNumberOfCompressedAccounts) for i := 0; i < int(ps.NonInclusionNumberOfCompressedAccounts); i++ { - logging.Logger().Debug().Msgf("ProveNonInclusion: Input[%d] NextIndex=%d", i, params.Inputs[i].NextIndex) roots[i] = params.Inputs[i].Root values[i] = params.Inputs[i].Value leafLowerRangeValues[i] = params.Inputs[i].LeafLowerRangeValue diff --git a/prover/server/prover/v2/marshal_batch_append_with_proofs.go 
b/prover/server/prover/v2/marshal_batch_append_with_proofs.go index 8fba0f59e3..86a7b04a2f 100644 --- a/prover/server/prover/v2/marshal_batch_append_with_proofs.go +++ b/prover/server/prover/v2/marshal_batch_append_with_proofs.go @@ -2,6 +2,7 @@ package v2 import ( "encoding/json" + "fmt" "light/light-prover/prover/common" "math/big" ) @@ -90,6 +91,12 @@ func (p *BatchAppendParameters) updateWithJSON(params BatchAppendInputsJSON) err return err } + // Validate array lengths match to prevent index out of range panic + if len(params.Leaves) != len(params.OldLeaves) { + return fmt.Errorf("array length mismatch: leaves=%d, oldLeaves=%d", + len(params.Leaves), len(params.OldLeaves)) + } + p.OldLeaves = make([]*big.Int, len(params.OldLeaves)) p.Leaves = make([]*big.Int, len(params.Leaves)) for i := 0; i < len(params.Leaves); i++ { diff --git a/prover/server/prover/v2/marshal_update.go b/prover/server/prover/v2/marshal_update.go index 47e0955896..f26fccf486 100644 --- a/prover/server/prover/v2/marshal_update.go +++ b/prover/server/prover/v2/marshal_update.go @@ -113,6 +113,12 @@ func (p *BatchUpdateParameters) UpdateWithJSON(params BatchUpdateProofInputsJSON return err } + // Validate array lengths match to prevent index out of range panic + if len(params.Leaves) != len(params.TxHashes) || len(params.Leaves) != len(params.OldLeaves) { + return fmt.Errorf("array length mismatch: leaves=%d, txHashes=%d, oldLeaves=%d", + len(params.Leaves), len(params.TxHashes), len(params.OldLeaves)) + } + p.TxHashes = make([]*big.Int, len(params.TxHashes)) p.Leaves = make([]*big.Int, len(params.Leaves)) p.OldLeaves = make([]*big.Int, len(params.OldLeaves)) diff --git a/prover/server/prover/v2/test_data_helpers.go b/prover/server/prover/v2/test_data_helpers.go index 3bdf357f72..d6c670f082 100644 --- a/prover/server/prover/v2/test_data_helpers.go +++ b/prover/server/prover/v2/test_data_helpers.go @@ -57,7 +57,7 @@ func BuildTestNonInclusionTree(depth int, numberOfCompressedAccounts int, 
random var values = make([]*big.Int, numberOfCompressedAccounts) var roots = make([]*big.Int, numberOfCompressedAccounts) - for i := 0; i < numberOfCompressedAccounts; i++ { + for i := range numberOfCompressedAccounts { var value = big.NewInt(0) var leafLower = big.NewInt(0) var leafUpper = big.NewInt(2) diff --git a/prover/server/publish_podman.sh b/prover/server/publish_podman.sh new file mode 100755 index 0000000000..b0ad5f711c --- /dev/null +++ b/prover/server/publish_podman.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -e + +PROJECT_ID=$(gcloud config get-value project) +REGION=europe-west1 +REPO_NAME=light +IMAGE_NAME=prover-light +TAG=latest +FULL_IMAGE=$REGION-docker.pkg.dev/$PROJECT_ID/$REPO_NAME/$IMAGE_NAME:$TAG + +# Authenticate podman with Google Artifact Registry +gcloud auth print-access-token | podman login -u oauth2accesstoken --password-stdin $REGION-docker.pkg.dev + +# Build for amd64 +podman build --platform linux/amd64 -t $IMAGE_NAME:$TAG -f Dockerfile.light . + +# Tag and push +podman tag $IMAGE_NAME:$TAG $FULL_IMAGE +podman push $FULL_IMAGE diff --git a/prover/server/publish_prover.sh b/prover/server/publish_prover.sh new file mode 100755 index 0000000000..a34afa7441 --- /dev/null +++ b/prover/server/publish_prover.sh @@ -0,0 +1,19 @@ +# 2.0.23 + +PROJECT_ID=$(gcloud config get-value project) +REGION=europe-west1 +REPO_NAME=light +TAG="${1:-latest}" # Usage: ./publish_prover.sh [version] + +docker buildx build --platform linux/amd64,linux/arm64 -f Dockerfile.light -t prover-light:$TAG --load . 
+docker tag prover-light:$TAG $REGION-docker.pkg.dev/$PROJECT_ID/$REPO_NAME/prover-light:$TAG +docker push $REGION-docker.pkg.dev/$PROJECT_ID/$REPO_NAME/prover-light:$TAG + +# Deploy to GKE +CLUSTER_NAME="prover-gcloud-500" +CLUSTER_ZONE="us-central1-a" +IMAGE="$REGION-docker.pkg.dev/$PROJECT_ID/$REPO_NAME/prover-light:$TAG" +echo "Deploying $IMAGE to GKE cluster: $CLUSTER_NAME" +gcloud container clusters get-credentials $CLUSTER_NAME --zone=$CLUSTER_ZONE --project=$PROJECT_ID +kubectl set image deployment/prover-universal prover=$IMAGE -n prover +kubectl rollout status deployment/prover-universal -n prover diff --git a/prover/server/server/metrics.go b/prover/server/server/metrics.go index 4d05fa23b8..fc778626a3 100644 --- a/prover/server/server/metrics.go +++ b/prover/server/server/metrics.go @@ -1,10 +1,12 @@ package server import ( + "runtime" "time" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "light/light-prover/logging" ) var ( @@ -33,6 +35,14 @@ var ( []string{"circuit_type", "error_type"}, ) + ProofPanicsTotal = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: "prover_proof_panics_total", + Help: "Total number of panics recovered during proof processing", + }, + []string{"circuit_type"}, + ) + QueueWaitTime = promauto.NewHistogramVec( prometheus.HistogramOpts{ Name: "prover_queue_wait_time_seconds", @@ -82,19 +92,58 @@ var ( }, []string{"circuit_type"}, ) + + // Memory metrics for proof generation + ProofMemoryUsage = promauto.NewHistogramVec( + prometheus.HistogramOpts{ + Name: "prover_proof_memory_usage_bytes", + Help: "Memory allocated during proof generation (heap alloc delta)", + Buckets: prometheus.ExponentialBuckets(1024*1024*100, 2, 12), // 100MB to ~200GB + }, + []string{"circuit_type"}, + ) + + ProofPeakMemory = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "prover_proof_peak_memory_bytes", + Help: "Peak memory observed during proof generation by circuit type", + }, 
+ []string{"circuit_type"}, + ) + + SystemMemoryUsage = promauto.NewGaugeVec( + prometheus.GaugeOpts{ + Name: "prover_system_memory_bytes", + Help: "System memory statistics", + }, + []string{"type"}, // heap_alloc, heap_sys, heap_inuse, sys + ) ) type MetricTimer struct { - start time.Time - circuitType string + start time.Time + circuitType string + startHeapAlloc uint64 } func StartProofTimer(circuitType string) *MetricTimer { ProofRequestsTotal.WithLabelValues(circuitType).Inc() ActiveJobs.Inc() + + // Capture memory state before proof generation + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + // Update system memory gauges + SystemMemoryUsage.WithLabelValues("heap_alloc").Set(float64(memStats.HeapAlloc)) + SystemMemoryUsage.WithLabelValues("heap_sys").Set(float64(memStats.HeapSys)) + SystemMemoryUsage.WithLabelValues("heap_inuse").Set(float64(memStats.HeapInuse)) + SystemMemoryUsage.WithLabelValues("sys").Set(float64(memStats.Sys)) + return &MetricTimer{ - start: time.Now(), - circuitType: circuitType, + start: time.Now(), + circuitType: circuitType, + startHeapAlloc: memStats.HeapAlloc, } } @@ -102,11 +151,54 @@ func (t *MetricTimer) ObserveDuration() { duration := time.Since(t.start).Seconds() ProofGenerationDuration.WithLabelValues(t.circuitType).Observe(duration) ActiveJobs.Dec() + + // Capture memory state after proof generation + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + + // Calculate memory delta (may be negative due to GC, use max with 0) + memDelta := int64(memStats.HeapAlloc) - int64(t.startHeapAlloc) + if memDelta < 0 { + memDelta = 0 + } + + // Record memory usage for this proof + ProofMemoryUsage.WithLabelValues(t.circuitType).Observe(float64(memDelta)) + + // Update peak memory if this is higher + currentPeak, _ := ProofPeakMemory.GetMetricWithLabelValues(t.circuitType) + if currentPeak != nil { + // Note: Gauge doesn't have a Get method, so we track via histogram max + } + + // Update system memory 
gauges + SystemMemoryUsage.WithLabelValues("heap_alloc").Set(float64(memStats.HeapAlloc)) + SystemMemoryUsage.WithLabelValues("heap_sys").Set(float64(memStats.HeapSys)) + SystemMemoryUsage.WithLabelValues("heap_inuse").Set(float64(memStats.HeapInuse)) + SystemMemoryUsage.WithLabelValues("sys").Set(float64(memStats.Sys)) + + // Log memory usage for debugging + logging.Logger().Info(). + Str("circuit_type", t.circuitType). + Float64("duration_sec", duration). + Uint64("start_heap_mb", t.startHeapAlloc/1024/1024). + Uint64("end_heap_mb", memStats.HeapAlloc/1024/1024). + Int64("delta_mb", memDelta/1024/1024). + Uint64("sys_mb", memStats.Sys/1024/1024). + Msg("Proof generation completed with memory stats") } func (t *MetricTimer) ObserveError(errorType string) { ProofGenerationErrors.WithLabelValues(t.circuitType, errorType).Inc() ActiveJobs.Dec() + + // Still record memory stats on error + var memStats runtime.MemStats + runtime.ReadMemStats(&memStats) + SystemMemoryUsage.WithLabelValues("heap_alloc").Set(float64(memStats.HeapAlloc)) + SystemMemoryUsage.WithLabelValues("heap_sys").Set(float64(memStats.HeapSys)) + SystemMemoryUsage.WithLabelValues("heap_inuse").Set(float64(memStats.HeapInuse)) + SystemMemoryUsage.WithLabelValues("sys").Set(float64(memStats.Sys)) } func RecordJobComplete(success bool) { diff --git a/prover/server/server/queue.go b/prover/server/server/queue.go index 267a36682f..314849b2a8 100644 --- a/prover/server/server/queue.go +++ b/prover/server/server/queue.go @@ -2,15 +2,25 @@ package server import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "fmt" "light/light-prover/logging" "light/light-prover/prover/common" "time" + "github.com/google/uuid" "github.com/redis/go-redis/v9" ) +const ( + // ResultsIndexKey is the Redis hash that maps inputHash → jobID + ResultsIndexKey = "zk_results_index" + // FailedIndexKey is the Redis hash that maps inputHash → jobID + FailedIndexKey = "zk_failed_index" +) + type RedisQueue struct { Client 
*redis.Client Ctx context.Context @@ -22,16 +32,35 @@ func NewRedisQueue(redisURL string) (*RedisQueue, error) { return nil, fmt.Errorf("failed to parse Redis URL: %w", err) } + // Configure connection pool and timeouts for Cloud Run + VPC connector reliability + opts.PoolSize = 500 // Connection pool size per instance (increased for high load) + opts.MinIdleConns = 10 // Keep some connections warm + opts.DialTimeout = 10 * time.Second // Timeout for establishing new connections + opts.ReadTimeout = 30 * time.Second // Timeout for read operations (BLPOP can be slow) + opts.WriteTimeout = 10 * time.Second // Timeout for write operations + opts.PoolTimeout = 15 * time.Second // Timeout for getting connection from pool + opts.ConnMaxIdleTime = 5 * time.Minute // Close idle connections after this time + opts.MaxRetries = 3 // Retry failed commands + client := redis.NewClient(opts) ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() if err := client.Ping(ctx).Err(); err != nil { return nil, fmt.Errorf("failed to connect to Redis: %w", err) } + logging.Logger().Info(). + Int("pool_size", opts.PoolSize). + Int("min_idle_conns", opts.MinIdleConns). + Dur("dial_timeout", opts.DialTimeout). + Dur("read_timeout", opts.ReadTimeout). + Dur("write_timeout", opts.WriteTimeout). + Int("max_retries", opts.MaxRetries). 
+ Msg("Redis client configured with connection pool") + return &RedisQueue{Client: client, Ctx: context.Background()}, nil } @@ -41,19 +70,100 @@ func (rq *RedisQueue) EnqueueProof(queueName string, job *ProofJob) error { return fmt.Errorf("failed to marshal job: %w", err) } - err = rq.Client.RPush(rq.Ctx, queueName, data).Err() + // Use tree-specific sub-queue for fair queuing if TreeID is set + actualQueueName := queueName + if job.TreeID != "" && isFairQueueEnabled(queueName) { + actualQueueName = fmt.Sprintf("%s:%s", queueName, job.TreeID) + // Track this tree in the trees set for round-robin + treesSetKey := fmt.Sprintf("%s:trees", queueName) + rq.Client.SAdd(rq.Ctx, treesSetKey, job.TreeID) + } + + err = rq.Client.RPush(rq.Ctx, actualQueueName, data).Err() if err != nil { return fmt.Errorf("failed to enqueue job: %w", err) } logging.Logger().Info(). Str("job_id", job.ID). - Str("queue", queueName). + Str("queue", actualQueueName). + Str("tree_id", job.TreeID). + Str("redis_addr", rq.Client.Options().Addr). Msg("Job enqueued successfully") return nil } +// isFairQueueEnabled returns true for queues that support fair queuing per tree +func isFairQueueEnabled(queueName string) bool { + return queueName == "zk_update_queue" || + queueName == "zk_append_queue" || + queueName == "zk_address_append_queue" +} + +// StoreJobMeta stores job metadata when a job is submitted to enable reliable status lookups. +// This ensures the status endpoint can find the job even before a worker picks it up. +// TTL is set to 1 hour to match result TTL. 
+func (rq *RedisQueue) StoreJobMeta(jobID string, queueName string, circuitType string) error { + key := fmt.Sprintf("zk_job_meta_%s", jobID) + meta := map[string]interface{}{ + "queue": queueName, + "circuit_type": circuitType, + "submitted_at": time.Now(), + "status": "queued", + } + data, err := json.Marshal(meta) + if err != nil { + return fmt.Errorf("failed to marshal job meta: %w", err) + } + + err = rq.Client.Set(rq.Ctx, key, data, 1*time.Hour).Err() + if err != nil { + return fmt.Errorf("failed to store job meta: %w", err) + } + + logging.Logger().Debug(). + Str("job_id", jobID). + Str("queue", queueName). + Str("circuit_type", circuitType). + Str("redis_addr", rq.Client.Options().Addr). + Msg("Stored job metadata for status tracking") + + return nil +} + +// GetJobMeta retrieves job metadata by job ID. +// Returns nil if the job metadata doesn't exist. +func (rq *RedisQueue) GetJobMeta(jobID string) (map[string]interface{}, error) { + key := fmt.Sprintf("zk_job_meta_%s", jobID) + result, err := rq.Client.Get(rq.Ctx, key).Result() + if err == redis.Nil { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to get job meta: %w", err) + } + + var meta map[string]interface{} + if err := json.Unmarshal([]byte(result), &meta); err != nil { + return nil, fmt.Errorf("failed to unmarshal job meta: %w", err) + } + + return meta, nil +} + +// DeleteJobMeta removes job metadata when a job completes or fails. 
+func (rq *RedisQueue) DeleteJobMeta(jobID string) error { + key := fmt.Sprintf("zk_job_meta_%s", jobID) + return rq.Client.Del(rq.Ctx, key).Err() +} + func (rq *RedisQueue) DequeueProof(queueName string, timeout time.Duration) (*ProofJob, error) { + // Check if this queue supports fair queuing + if isFairQueueEnabled(queueName) { + return rq.dequeueWithFairQueuing(queueName, timeout) + } + + // Standard dequeue for non-fair queues result, err := rq.Client.BLPop(rq.Ctx, timeout, queueName).Result() if err != nil { if err == redis.Nil { @@ -75,6 +185,207 @@ func (rq *RedisQueue) DequeueProof(queueName string, timeout time.Duration) (*Pr return &job, nil } +// dequeueWithFairQueuing implements round-robin dequeuing across tree-specific sub-queues +// Within each tree's queue, it prioritizes jobs with lower batch_index to ensure sequential processing +func (rq *RedisQueue) dequeueWithFairQueuing(queueName string, timeout time.Duration) (*ProofJob, error) { + treesSetKey := fmt.Sprintf("%s:trees", queueName) + lastTreeKey := fmt.Sprintf("%s:last_tree", queueName) + + // Get all trees with pending jobs + trees, err := rq.Client.SMembers(rq.Ctx, treesSetKey).Result() + if err != nil { + return nil, fmt.Errorf("failed to get trees set: %w", err) + } + + // If no trees with jobs, fall back to main queue (for jobs without tree_id) + if len(trees) == 0 { + result, err := rq.Client.BLPop(rq.Ctx, timeout, queueName).Result() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, fmt.Errorf("failed to dequeue job: %w", err) + } + if len(result) < 2 { + return nil, fmt.Errorf("invalid result from Redis") + } + var job ProofJob + err = json.Unmarshal([]byte(result[1]), &job) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal job: %w", err) + } + return &job, nil + } + + // Get the last processed tree to start round-robin from next + lastTree, _ := rq.Client.Get(rq.Ctx, lastTreeKey).Result() + + // Find starting index for round-robin + startIdx 
:= 0 + for i, tree := range trees { + if tree == lastTree { + startIdx = (i + 1) % len(trees) + break + } + } + + // Try each tree in round-robin order + for i := range len(trees) { + idx := (startIdx + i) % len(trees) + tree := trees[idx] + subQueueName := fmt.Sprintf("%s:%s", queueName, tree) + + // Get job with lowest batch_index from this tree's queue + job, err := rq.dequeueLowestBatchIndex(subQueueName) + if err == redis.Nil || job == nil { + // Queue empty, remove tree from set + rq.Client.SRem(rq.Ctx, treesSetKey, tree) + continue + } + if err != nil { + logging.Logger().Warn(). + Err(err). + Str("queue", subQueueName). + Msg("Error getting lowest batch_index job from tree sub-queue") + continue + } + + // Update last processed tree for next round-robin + rq.Client.Set(rq.Ctx, lastTreeKey, tree, 1*time.Hour) + + // Check if queue is now empty and remove from trees set + queueLen, _ := rq.Client.LLen(rq.Ctx, subQueueName).Result() + if queueLen == 0 { + rq.Client.SRem(rq.Ctx, treesSetKey, tree) + } + + logging.Logger().Debug(). + Str("job_id", job.ID). + Str("tree_id", tree). + Int64("batch_index", job.BatchIndex). + Str("queue", subQueueName). + Int("trees_count", len(trees)). + Msg("Dequeued job with fair queuing and batch_index priority") + + return job, nil + } + + // All tree queues were empty, try main queue as fallback + result, err := rq.Client.BLPop(rq.Ctx, timeout, queueName).Result() + if err != nil { + if err == redis.Nil { + return nil, nil + } + return nil, fmt.Errorf("failed to dequeue job: %w", err) + } + if len(result) < 2 { + return nil, fmt.Errorf("invalid result from Redis") + } + var job ProofJob + err = json.Unmarshal([]byte(result[1]), &job) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal job: %w", err) + } + return &job, nil +} + +// BatchIndexScanLimit is the maximum number of items to scan when looking for the lowest batch_index. 
+const BatchIndexScanLimit = 100 + +// dequeueLowestBatchIndex finds and removes the job with the lowest batch_index from the queue. +// This ensures that batches are processed in order within each tree, enabling the forester +// to send transactions sequentially as proofs complete. +// Jobs with batch_index -1 (legacy) are treated as having the highest priority among themselves +// but after jobs with explicit batch indices. +// +// Scans up to BatchIndexScanLimit items for performance. If the item was removed by another +// worker between find and remove, retries automatically. +func (rq *RedisQueue) dequeueLowestBatchIndex(queueName string) (*ProofJob, error) { + // Scan up to BatchIndexScanLimit items instead of the entire queue + items, err := rq.Client.LRange(rq.Ctx, queueName, 0, BatchIndexScanLimit-1).Result() + if err != nil { + return nil, err + } + + if len(items) == 0 { + return nil, redis.Nil + } + + if len(items) == 1 { + result, err := rq.Client.LPop(rq.Ctx, queueName).Result() + if err != nil { + return nil, err + } + var job ProofJob + if err := json.Unmarshal([]byte(result), &job); err != nil { + return nil, err + } + return &job, nil + } + + var lowestJob *ProofJob + lowestIdx := -1 + lowestBatchIndex := int64(^uint64(0) >> 1) + + for i, item := range items { + var job ProofJob + if err := json.Unmarshal([]byte(item), &job); err != nil { + logging.Logger().Warn(). + Err(err). + Str("queue", queueName). + Int("index", i). 
+ Msg("Failed to unmarshal job while searching for lowest batch_index") + continue + } + + // Jobs with batch_index >= 0 have priority over legacy jobs (batch_index -1) + // Among jobs with batch_index >= 0, lower index wins + // Among legacy jobs, first in queue wins (FIFO) + if job.BatchIndex >= 0 { + if lowestJob == nil || lowestJob.BatchIndex < 0 || job.BatchIndex < lowestBatchIndex { + lowestJob = &job + lowestIdx = i + lowestBatchIndex = job.BatchIndex + } + } else if lowestJob == nil || (lowestJob.BatchIndex < 0 && lowestIdx > i) { + // Legacy job, only take if no better candidate or this is earlier in queue + lowestJob = &job + lowestIdx = i + lowestBatchIndex = job.BatchIndex + } + } + + if lowestJob == nil { + return nil, redis.Nil + } + + // Remove the selected job from the queue + itemToRemove := items[lowestIdx] + removed, err := rq.Client.LRem(rq.Ctx, queueName, 1, itemToRemove).Result() + if err != nil { + return nil, fmt.Errorf("failed to remove job from queue: %w", err) + } + + if removed == 0 { + // Item was already removed by another worker, retry + logging.Logger().Debug(). + Str("job_id", lowestJob.ID). + Str("queue", queueName). + Msg("Job was already removed from queue, retrying") + return rq.dequeueLowestBatchIndex(queueName) + } + + logging.Logger().Debug(). + Str("job_id", lowestJob.ID). + Int64("batch_index", lowestJob.BatchIndex). + Int("queue_position", lowestIdx). + Int("scanned", len(items)). + Str("queue", queueName). 
+ Msg("Dequeued job with lowest batch_index") + + return lowestJob, nil +} + func (rq *RedisQueue) GetQueueStats() (map[string]int64, error) { stats := make(map[string]int64) @@ -87,6 +398,22 @@ func (rq *RedisQueue) GetQueueStats() (map[string]int64, error) { length = 0 } stats[queue] = length + + // For fair-queued queues, also count tree sub-queues + if isFairQueueEnabled(queue) { + treesSetKey := fmt.Sprintf("%s:trees", queue) + trees, err := rq.Client.SMembers(rq.Ctx, treesSetKey).Result() + if err == nil { + var totalTreeQueueLen int64 + for _, tree := range trees { + subQueueName := fmt.Sprintf("%s:%s", queue, tree) + subLen, _ := rq.Client.LLen(rq.Ctx, subQueueName).Result() + totalTreeQueueLen += subLen + } + stats[queue+"_tree_subqueues"] = totalTreeQueueLen + stats[queue+"_tree_count"] = int64(len(trees)) + } + } } return stats, nil @@ -155,8 +482,8 @@ func (rq *RedisQueue) GetResult(jobID string) (interface{}, error) { key := fmt.Sprintf("zk_result_%s", jobID) result, err := rq.Client.Get(rq.Ctx, key).Result() if err == nil { - var proof common.Proof - err = json.Unmarshal([]byte(result), &proof) + var proofWithTiming common.ProofWithTiming + err = json.Unmarshal([]byte(result), &proofWithTiming) if err != nil { logging.Logger().Error(). Str("job_id", jobID). 
@@ -166,7 +493,7 @@ func (rq *RedisQueue) GetResult(jobID string) (interface{}, error) { return nil, fmt.Errorf("failed to unmarshal direct result: %w", err) } - return &proof, nil + return &proofWithTiming, nil } if err != redis.Nil { @@ -186,14 +513,14 @@ func (rq *RedisQueue) searchResultInQueue(jobID string) (interface{}, error) { var resultJob ProofJob if json.Unmarshal([]byte(item), &resultJob) == nil { if resultJob.ID == jobID && resultJob.Type == "result" { - var proof common.Proof - err = json.Unmarshal(resultJob.Payload, &proof) + var proofWithTiming common.ProofWithTiming + err = json.Unmarshal(resultJob.Payload, &proofWithTiming) if err != nil { return nil, fmt.Errorf("failed to unmarshal queued result: %w", err) } - rq.StoreResult(jobID, &proof) + rq.StoreResult(jobID, &proofWithTiming) - return &proof, nil + return &proofWithTiming, nil } } } @@ -225,6 +552,44 @@ func (rq *RedisQueue) StoreResult(jobID string, result interface{}) error { return nil } +// IndexResultByHash atomically adds inputHash → jobID to the results index hash. +func (rq *RedisQueue) IndexResultByHash(inputHash, jobID string) error { + err := rq.Client.HSet(rq.Ctx, ResultsIndexKey, inputHash, jobID).Err() + if err != nil { + return fmt.Errorf("failed to index result: %w", err) + } + logging.Logger().Debug(). + Str("input_hash", inputHash). + Str("job_id", jobID). + Msg("Indexed result by input hash") + return nil +} + +// IndexFailureByHash atomically adds inputHash → jobID to the failed index hash. +func (rq *RedisQueue) IndexFailureByHash(inputHash, jobID string) error { + err := rq.Client.HSet(rq.Ctx, FailedIndexKey, inputHash, jobID).Err() + if err != nil { + return fmt.Errorf("failed to index failure: %w", err) + } + logging.Logger().Debug(). + Str("input_hash", inputHash). + Str("job_id", jobID). + Msg("Indexed failure by input hash") + return nil +} + +// RemoveResultIndex removes inputHash from the results index hash. 
+// Called during cleanup to keep the index in sync with the queue. +func (rq *RedisQueue) RemoveResultIndex(inputHash string) error { + return rq.Client.HDel(rq.Ctx, ResultsIndexKey, inputHash).Err() +} + +// RemoveFailureIndex removes inputHash from the failed index hash. +// Called during cleanup to keep the index in sync with the queue. +func (rq *RedisQueue) RemoveFailureIndex(inputHash string) error { + return rq.Client.HDel(rq.Ctx, FailedIndexKey, inputHash).Err() +} + func (rq *RedisQueue) CleanupOldResults() error { // Remove successful results older than 1 hour cutoffTime := time.Now().Add(-1 * time.Hour) @@ -234,6 +599,7 @@ func (rq *RedisQueue) CleanupOldResults() error { logging.Logger().Error(). Err(err). Msg("Failed to cleanup old results by time") + return err } if removed > 0 { @@ -243,23 +609,6 @@ func (rq *RedisQueue) CleanupOldResults() error { Msg("Cleaned up old results by time") } - ctx := context.Background() - length, err := rq.Client.LLen(ctx, "zk_results_queue").Result() - if err != nil { - return err - } - - if length > 1000 { - toRemove := length - 1000 - for i := int64(0); i < toRemove; i++ { - rq.Client.LPop(ctx, "zk_results_queue") - } - - logging.Logger().Info(). - Int64("removed_items", toRemove). - Msg("Cleaned up old results from queue (length-based safety)") - } - return nil } @@ -276,15 +625,42 @@ func (rq *RedisQueue) CleanupOldRequests() error { totalRemoved := int64(0) for _, queueName := range queuesToClean { + // Clean main queue removed, err := rq.cleanupOldRequestsFromQueue(queueName, cutoffTime) if err != nil { logging.Logger().Error(). Err(err). Str("queue", queueName). 
Msg("Failed to cleanup old requests from queue") - continue + } else { + totalRemoved += removed + } + + // Clean tree sub-queues for fair-queued queues + if isFairQueueEnabled(queueName) { + treesSetKey := fmt.Sprintf("%s:trees", queueName) + trees, err := rq.Client.SMembers(rq.Ctx, treesSetKey).Result() + if err == nil { + for _, tree := range trees { + subQueueName := fmt.Sprintf("%s:%s", queueName, tree) + subRemoved, err := rq.cleanupOldRequestsFromQueue(subQueueName, cutoffTime) + if err != nil { + logging.Logger().Error(). + Err(err). + Str("queue", subQueueName). + Msg("Failed to cleanup old requests from tree sub-queue") + continue + } + totalRemoved += subRemoved + + // If tree queue is now empty, remove from trees set + queueLen, _ := rq.Client.LLen(rq.Ctx, subQueueName).Result() + if queueLen == 0 { + rq.Client.SRem(rq.Ctx, treesSetKey, tree) + } + } + } } - totalRemoved += removed } if totalRemoved > 0 { @@ -340,8 +716,9 @@ func (rq *RedisQueue) CleanupOldResultKeys() error { } func (rq *RedisQueue) CleanupStuckProcessingJobs() error { - // Jobs stuck in processing for more than 2 minutes are considered stuck - processingTimeout := time.Now().Add(-2 * time.Minute) + // Jobs stuck in processing for more than 10 minutes are considered stuck + // (proof generation can take 3-4 minutes under load) + processingTimeout := time.Now().Add(-10 * time.Minute) processingQueues := []string{ "zk_update_processing_queue", @@ -428,12 +805,23 @@ func (rq *RedisQueue) recoverStuckJobsFromQueue(queueName string, timeoutCutoff fiveMinutesAgo := time.Now().Add(-5 * time.Minute) if job.CreatedAt.Before(fiveMinutesAgo) { + // Extract circuit type from payload for debugging, but don't store full payload + // to prevent memory issues (payloads can be hundreds of KB) + var circuitType string + var payloadMeta map[string]interface{} + if json.Unmarshal(job.Payload, &payloadMeta) == nil { + if ct, ok := payloadMeta["circuitType"].(string); ok { + circuitType = ct + } + } + 
failureDetails := map[string]interface{}{ "original_job": map[string]interface{}{ - "id": originalJobID, - "type": "zk_proof", - "payload": job.Payload, - "created_at": job.CreatedAt, + "id": originalJobID, + "type": "zk_proof", + "circuit_type": circuitType, + "payload_size": len(job.Payload), + "created_at": job.CreatedAt, }, "error": "Job timed out in processing queue (stuck for >5 minutes)", "failed_at": time.Now(), @@ -535,6 +923,10 @@ func (rq *RedisQueue) cleanupOldRequestsFromQueue(queueName string, cutoffTime t } if count > 0 { removedCount++ + + // Also clean up the index entry if this is a results/failed queue + rq.cleanupIndexEntry(queueName, job.ID) + logging.Logger().Debug(). Str("job_id", job.ID). Str("queue", queueName). @@ -547,3 +939,430 @@ func (rq *RedisQueue) cleanupOldRequestsFromQueue(queueName string, cutoffTime t return removedCount, nil } + +// cleanupIndexEntry removes the hash index entry for a job being cleaned up. +// Looks up the inputHash from zk_input_hash_{jobID} and removes from the appropriate index. 
+func (rq *RedisQueue) cleanupIndexEntry(queueName string, jobID string) { + // Extract original job ID (remove _failed suffix if present) + originalJobID := jobID + if len(jobID) > 7 && jobID[len(jobID)-7:] == "_failed" { + originalJobID = jobID[:len(jobID)-7] + } + + // Look up the input hash for this job + inputHash, err := rq.Client.Get(rq.Ctx, fmt.Sprintf("zk_input_hash_%s", originalJobID)).Result() + if err != nil { + // Input hash not found or expired - nothing to clean up + return + } + + // Remove from the appropriate index based on queue type + switch queueName { + case "zk_results_queue": + rq.RemoveResultIndex(inputHash) + case "zk_failed_queue": + rq.RemoveFailureIndex(inputHash) + } +} + +// ComputeInputHash computes a SHA256 hash of the proof input payload +func ComputeInputHash(payload json.RawMessage) string { + hash := sha256.Sum256(payload) + return hex.EncodeToString(hash[:]) +} + +// FindCachedResult searches for a cached result by input hash. +// Returns the proof result (as ProofWithTiming) and job ID if found, otherwise returns nil. +func (rq *RedisQueue) FindCachedResult(inputHash string) (*common.ProofWithTiming, string, error) { + jobID, err := rq.Client.HGet(rq.Ctx, ResultsIndexKey, inputHash).Result() + if err == nil && jobID != "" { + result, fetchErr := rq.Client.Get(rq.Ctx, fmt.Sprintf("zk_result_%s", jobID)).Result() + if fetchErr == nil { + var proofWithTiming common.ProofWithTiming + if json.Unmarshal([]byte(result), &proofWithTiming) == nil { + logging.Logger().Info(). + Str("input_hash", inputHash). + Str("cached_job_id", jobID). + Int64("proof_duration_ms", proofWithTiming.ProofDurationMs). + Msg("Found cached successful proof result via index") + return &proofWithTiming, jobID, nil + } + } + // Index entry exists but result is missing/invalid - clean up stale index entry + logging.Logger().Debug(). + Str("input_hash", inputHash). + Str("job_id", jobID). 
+ Msg("Stale index entry, removing and falling back to queue scan") + rq.RemoveResultIndex(inputHash) + } else if err != nil && err != redis.Nil { + logging.Logger().Warn(). + Err(err). + Str("input_hash", inputHash). + Msg("Error querying results index, falling back to queue scan") + } + + // Fallback: O(n) queue scan for backward compatibility with unindexed results + items, err := rq.Client.LRange(rq.Ctx, "zk_results_queue", 0, -1).Result() + if err != nil { + return nil, "", fmt.Errorf("failed to search results queue: %w", err) + } + + for _, item := range items { + var resultJob ProofJob + if json.Unmarshal([]byte(item), &resultJob) == nil && resultJob.Type == "result" { + // Check if this result has the same input hash + storedHash, err := rq.Client.Get(rq.Ctx, fmt.Sprintf("zk_input_hash_%s", resultJob.ID)).Result() + if err == nil && storedHash == inputHash { + var proofWithTiming common.ProofWithTiming + err = json.Unmarshal(resultJob.Payload, &proofWithTiming) + if err != nil { + logging.Logger().Warn(). + Err(err). + Str("input_hash", inputHash). + Str("job_id", resultJob.ID). + Msg("Failed to unmarshal cached result payload, skipping") + continue + } + + logging.Logger().Info(). + Str("input_hash", inputHash). + Str("cached_job_id", resultJob.ID). + Int64("proof_duration_ms", proofWithTiming.ProofDurationMs). + Msg("Found cached successful proof result via queue scan") + + rq.IndexResultByHash(inputHash, resultJob.ID) + + return &proofWithTiming, resultJob.ID, nil + } + } + } + + return nil, "", nil +} + +// FindCachedFailure searches for a cached failure by input hash. +// Returns the failure details and job ID if found, otherwise returns nil. 
+func (rq *RedisQueue) FindCachedFailure(inputHash string) (map[string]interface{}, string, error) { + jobID, err := rq.Client.HGet(rq.Ctx, FailedIndexKey, inputHash).Result() + if err == nil && jobID != "" { + // Found in index, search for the job in failed queue by ID + items, err := rq.Client.LRange(rq.Ctx, "zk_failed_queue", 0, -1).Result() + if err == nil { + failedJobID := jobID + "_failed" + for _, item := range items { + var failedJob ProofJob + if json.Unmarshal([]byte(item), &failedJob) == nil && failedJob.ID == failedJobID { + var failureDetails map[string]interface{} + if json.Unmarshal(failedJob.Payload, &failureDetails) == nil { + logging.Logger().Info(). + Str("input_hash", inputHash). + Str("cached_job_id", jobID). + Msg("Found cached failed proof result via index") + return failureDetails, jobID, nil + } + } + } + } + // Index entry exists but failure job not found - clean up stale index entry + logging.Logger().Debug(). + Str("input_hash", inputHash). + Str("job_id", jobID). + Msg("Stale failure index entry, removing and falling back to queue scan") + rq.RemoveFailureIndex(inputHash) + } else if err != nil && err != redis.Nil { + logging.Logger().Warn(). + Err(err). + Str("input_hash", inputHash). 
+ Msg("Error querying failed index, falling back to queue scan") + } + + // Fallback: O(n) queue scan for backward compatibility with unindexed failures + items, err := rq.Client.LRange(rq.Ctx, "zk_failed_queue", 0, -1).Result() + if err != nil { + return nil, "", fmt.Errorf("failed to search failed queue: %w", err) + } + + for _, item := range items { + var failedJob ProofJob + if json.Unmarshal([]byte(item), &failedJob) == nil && failedJob.Type == "failed" { + // Extract the original job ID (remove _failed suffix) + originalJobID := failedJob.ID + if len(failedJob.ID) > 7 && failedJob.ID[len(failedJob.ID)-7:] == "_failed" { + originalJobID = failedJob.ID[:len(failedJob.ID)-7] + } + + // Check if this failure has the same input hash + storedHash, err := rq.Client.Get(rq.Ctx, fmt.Sprintf("zk_input_hash_%s", originalJobID)).Result() + if err == nil && storedHash == inputHash { + // Found a matching failure + var failureDetails map[string]interface{} + err = json.Unmarshal(failedJob.Payload, &failureDetails) + if err != nil { + continue + } + + logging.Logger().Info(). + Str("input_hash", inputHash). + Str("cached_job_id", originalJobID). + Msg("Found cached failed proof result via queue scan") + + // Backfill the index for future O(1) lookups + rq.IndexFailureByHash(inputHash, originalJobID) + + return failureDetails, originalJobID, nil + } + } + } + + return nil, "", nil +} + +// StoreInputHash stores the input hash for a job ID to enable deduplication +func (rq *RedisQueue) StoreInputHash(jobID string, inputHash string) error { + key := fmt.Sprintf("zk_input_hash_%s", jobID) + err := rq.Client.Set(rq.Ctx, key, inputHash, 1*time.Hour).Err() + if err != nil { + return fmt.Errorf("failed to store input hash: %w", err) + } + + logging.Logger().Debug(). + Str("job_id", jobID). + Str("input_hash", inputHash). + Msg("Stored input hash for deduplication") + + return nil +} + +// GetOrSetInFlightJob atomically checks if a job with the given input hash is already in-flight. 
+// If not, it registers the new job ID. Returns the existing job ID if found, or the new job ID if set. +// The isNew return value indicates whether this is a new job (true) or an existing one (false). +// TTL is set to 10 minutes to match the forester's max wait time. +func (rq *RedisQueue) GetOrSetInFlightJob(inputHash, jobID string) (existingJobID string, isNew bool, err error) { + key := fmt.Sprintf("zk_inflight_%s", inputHash) + + // Try to set the key atomically - only succeeds if key doesn't exist + set, err := rq.Client.SetNX(rq.Ctx, key, jobID, 10*time.Minute).Result() + if err != nil { + return "", false, fmt.Errorf("failed to check/set in-flight job: %w", err) + } + + if set { + // Key was set - this is a new job + // Also store reverse mapping so we can find the input hash from job ID + // This is needed for CleanupStaleInFlightMarker when job_not_found + reverseKey := fmt.Sprintf("zk_input_hash_%s", jobID) + rq.Client.Set(rq.Ctx, reverseKey, inputHash, 10*time.Minute) + + logging.Logger().Debug(). + Str("job_id", jobID). + Str("input_hash", inputHash). 
+ Msg("Registered new in-flight job") + return jobID, true, nil + } + + // Key already exists - get the existing job ID + existing, err := rq.Client.Get(rq.Ctx, key).Result() + if err != nil { + // Key might have expired between SetNX and Get - retry + if err == redis.Nil { + ok, err := rq.Client.SetNX(rq.Ctx, key, jobID, 10*time.Minute).Result() + if err != nil { + return "", false, fmt.Errorf("failed to set in-flight job on retry: %w", err) + } + if !ok { + // Another worker won the race - fetch their job ID + existing, err := rq.Client.Get(rq.Ctx, key).Result() + if err != nil { + return "", false, fmt.Errorf("failed to get winning job after retry race: %w", err) + } + return existing, false, nil + } + // We won the retry - store reverse mapping for cleanup + reverseKey := fmt.Sprintf("zk_input_hash_%s", jobID) + rq.Client.Set(rq.Ctx, reverseKey, inputHash, 10*time.Minute) + return jobID, true, nil + } + return "", false, fmt.Errorf("failed to get existing in-flight job: %w", err) + } + + logging.Logger().Info(). + Str("existing_job_id", existing). + Str("input_hash", inputHash). + Msg("Found existing in-flight job with same input") + + return existing, false, nil +} + +// DeleteInFlightJob removes the in-flight marker for a job when it completes. +// This should be called when a job finishes (success or failure) to allow +// new jobs with the same input to be queued. +func (rq *RedisQueue) DeleteInFlightJob(inputHash, jobID string) error { + key := fmt.Sprintf("zk_inflight_%s", inputHash) + err := rq.Client.Del(rq.Ctx, key).Err() + if err != nil { + return fmt.Errorf("failed to delete in-flight job marker: %w", err) + } + + // Also clean up the reverse mapping + reverseKey := fmt.Sprintf("zk_input_hash_%s", jobID) + rq.Client.Del(rq.Ctx, reverseKey) + + logging.Logger().Debug(). + Str("input_hash", inputHash). + Str("job_id", jobID). 
+ Msg("Deleted in-flight job marker") + + return nil +} + +// SetInFlightJob sets the in-flight marker for a job, replacing any existing marker. +// This is used when recovering from a stale marker to register a new job. +// Also sets the reverse mapping (jobID → inputHash) for cleanup. +func (rq *RedisQueue) SetInFlightJob(inputHash, jobID string, ttl time.Duration) error { + key := fmt.Sprintf("zk_inflight_%s", inputHash) + err := rq.Client.Set(rq.Ctx, key, jobID, ttl).Err() + if err != nil { + return fmt.Errorf("failed to set in-flight job marker: %w", err) + } + + // Also store reverse mapping so we can find the input hash from job ID + reverseKey := fmt.Sprintf("zk_input_hash_%s", jobID) + if reverseErr := rq.Client.Set(rq.Ctx, reverseKey, inputHash, ttl).Err(); reverseErr != nil { + logging.Logger().Warn(). + Err(reverseErr). + Str("job_id", jobID). + Msg("Failed to set reverse mapping (non-critical)") + } + + logging.Logger().Debug(). + Str("input_hash", inputHash). + Str("job_id", jobID). + Dur("ttl", ttl). + Msg("Set in-flight job marker") + + return nil +} + +// CleanupStaleInFlightMarker removes a stale in-flight marker for a job that no longer exists. +// This is called when a status check returns job_not_found, indicating the job was lost +// (e.g., due to prover restart) but the in-flight marker still exists. +// This allows new requests with the same input to create a new job instead of being +// deduplicated to the stale job ID. 
+func (rq *RedisQueue) CleanupStaleInFlightMarker(jobID string) { + // Get the input hash associated with this job ID + inputHashKey := fmt.Sprintf("zk_input_hash_%s", jobID) + inputHash, err := rq.Client.Get(rq.Ctx, inputHashKey).Result() + if err != nil { + // No input hash found - nothing to clean up + return + } + + // Check if the in-flight marker points to this job ID + inFlightKey := fmt.Sprintf("zk_inflight_%s", inputHash) + storedJobID, err := rq.Client.Get(rq.Ctx, inFlightKey).Result() + if err != nil { + // No in-flight marker - nothing to clean up + return + } + + // Only delete if this marker points to the stale job + if storedJobID == jobID { + rq.Client.Del(rq.Ctx, inFlightKey) + logging.Logger().Info(). + Str("job_id", jobID). + Str("input_hash", inputHash). + Msg("Cleaned up stale in-flight marker for lost job") + } + + // Also clean up the input hash mapping + rq.Client.Del(rq.Ctx, inputHashKey) +} + +// DeduplicationResult contains the result of a job deduplication check. +type DeduplicationResult struct { + // JobID is the resolved job ID to use (either new or existing). + JobID string + // IsNew indicates this is a new job that needs to be enqueued. + IsNew bool + // IsDeduplicated indicates the request was deduplicated to an existing job. + IsDeduplicated bool + // StaleJobID is set when a stale job was found and cleaned up. + StaleJobID string +} + +// DeduplicateJob checks for an existing in-flight job with the same input hash. +// If an existing job is found and still valid (has result or metadata), it returns +// that job's ID with IsDeduplicated=true. If an existing marker points to a stale +// job (no result/metadata), it cleans up the stale marker and creates a new job. +// Returns the resolved job ID and flags indicating the deduplication outcome. +// +// The TTL for in-flight markers is 10 minutes to match the forester's max wait time. 
+func (rq *RedisQueue) DeduplicateJob(inputHash string) (*DeduplicationResult, error) { + // Generate a new job ID + newJobID := uuid.New().String() + + // Try to atomically set our job as in-flight + existingJobID, isNew, err := rq.GetOrSetInFlightJob(inputHash, newJobID) + if err != nil { + logging.Logger().Warn(). + Err(err). + Str("input_hash", inputHash). + Msg("Failed to check for in-flight job, proceeding with new job") + // On error, proceed with the new job + return &DeduplicationResult{ + JobID: newJobID, + IsNew: true, + }, nil + } + + // If we successfully set a new job, we're done + if isNew { + return &DeduplicationResult{ + JobID: existingJobID, // This is our newJobID + IsNew: true, + }, nil + } + + // An existing job was found - verify it actually exists + jobExists := false + if result, _ := rq.GetResult(existingJobID); result != nil { + jobExists = true + } else if jobMeta, _ := rq.GetJobMeta(existingJobID); jobMeta != nil { + jobExists = true + } + + if jobExists { + // Valid existing job found - deduplicate to it + logging.Logger().Info(). + Str("existing_job_id", existingJobID). + Str("input_hash", inputHash). + Msg("Deduplicated proof request to existing job") + + return &DeduplicationResult{ + JobID: existingJobID, + IsNew: false, + IsDeduplicated: true, + }, nil + } + + // Job doesn't exist - stale marker found + logging.Logger().Warn(). + Str("stale_job_id", existingJobID). + Str("input_hash", inputHash). 
+ Msg("Deduplication found stale job ID - cleaning up and creating new job") + + // Clean up the stale marker + rq.CleanupStaleInFlightMarker(existingJobID) + + // Generate a fresh job ID and set new in-flight marker + freshJobID := uuid.New().String() + if err := rq.SetInFlightJob(inputHash, freshJobID, 10*time.Minute); err != nil { + return nil, fmt.Errorf("failed to set in-flight marker after stale cleanup: %w", err) + } + + return &DeduplicationResult{ + JobID: freshJobID, + IsNew: true, + StaleJobID: existingJobID, + }, nil +} diff --git a/prover/server/server/queue_job.go b/prover/server/server/queue_job.go index e8d3074570..9a225290fe 100644 --- a/prover/server/server/queue_job.go +++ b/prover/server/server/queue_job.go @@ -5,22 +5,109 @@ import ( "fmt" "light/light-prover/logging" "light/light-prover/prover/common" - "light/light-prover/prover/v1" - "light/light-prover/prover/v2" + v1 "light/light-prover/prover/v1" + v2 "light/light-prover/prover/v2" "log" + "os" + "strconv" "time" ) const ( // JobExpirationTimeout should match the forester's max_wait_time (600 seconds) JobExpirationTimeout = 600 * time.Second + + // Memory estimates per circuit type (in GB) + // Based on live measurements: ~11GB per batch-500 proof + // batch_update_32_500: ~11GB (8M constraints) + // batch_append_32_500: ~11GB (7.8M constraints) + // batch_address-append_40_250: ~15GB (larger tree height) + // + // For safety, we use the largest (address-append) as the baseline + MemoryPerProofGB = 15 + + // MemoryReserveGB is memory to reserve for OS, proving keys, and other processes + // Proving keys can be 10-20GB depending on which circuits are loaded + MemoryReserveGB = 20 + + // NumQueueWorkers is the number of queue workers (update, append, address-append) + NumQueueWorkers = 3 + + // MinConcurrencyPerWorker is the minimum concurrency per worker + MinConcurrencyPerWorker = 1 + + // MaxConcurrencyPerWorker is the maximum concurrency per worker (safety cap) + MaxConcurrencyPerWorker 
= 100 ) +// getMaxConcurrency returns the max concurrency per worker. +// Configuration priority: +// 1. PROVER_MAX_CONCURRENCY env var +// 2. Auto-calculate from PROVER_TOTAL_MEMORY_GB env var +// 3. Default to MinConcurrencyPerWorker +func getMaxConcurrency() int { + // Check for explicit concurrency override + if val := os.Getenv("PROVER_MAX_CONCURRENCY"); val != "" { + if concurrency, err := strconv.Atoi(val); err == nil && concurrency > 0 { + logging.Logger().Info(). + Int("max_concurrency", concurrency). + Msg("Using PROVER_MAX_CONCURRENCY") + return concurrency + } + } + + // Check for memory-based configuration + if val := os.Getenv("PROVER_TOTAL_MEMORY_GB"); val != "" { + if totalMemGB, err := strconv.Atoi(val); err == nil && totalMemGB > 0 { + concurrency := calculateConcurrency(totalMemGB) + logging.Logger().Info(). + Int("total_memory_gb", totalMemGB). + Int("max_concurrency", concurrency). + Msg("Calculated concurrency from PROVER_TOTAL_MEMORY_GB") + return concurrency + } + } + + // Default fallback + logging.Logger().Info(). + Int("max_concurrency", MinConcurrencyPerWorker). + Msg("Using default min concurrency (set PROVER_MAX_CONCURRENCY or PROVER_TOTAL_MEMORY_GB to configure)") + return MinConcurrencyPerWorker +} + +// calculateConcurrency computes per-worker concurrency from total memory. 
+// Formula: (TotalRAM - Reserve) / (MemoryPerProof * NumWorkers) +func calculateConcurrency(totalMemGB int) int { + availableMemGB := totalMemGB - MemoryReserveGB + if availableMemGB < MemoryPerProofGB { + return MinConcurrencyPerWorker + } + + totalConcurrentProofs := availableMemGB / MemoryPerProofGB + perWorkerConcurrency := totalConcurrentProofs / NumQueueWorkers + + if perWorkerConcurrency < MinConcurrencyPerWorker { + return MinConcurrencyPerWorker + } + if perWorkerConcurrency > MaxConcurrencyPerWorker { + return MaxConcurrencyPerWorker + } + + return perWorkerConcurrency +} + type ProofJob struct { ID string `json:"id"` Type string `json:"type"` Payload json.RawMessage `json:"payload"` CreatedAt time.Time `json:"created_at"` + // TreeID is the merkle tree pubkey - used for fair queuing across trees + // If empty, job goes to the default queue (backwards compatible) + TreeID string `json:"tree_id,omitempty"` + // BatchIndex is the batch sequence number within a tree - used to process batches in order + // Lower batch indices should be processed first to enable sequential transaction submission + // -1 means no batch index (legacy requests, FIFO) + BatchIndex int64 `json:"batch_index"` } type QueueWorker interface { @@ -34,6 +121,8 @@ type BaseQueueWorker struct { stopChan chan struct{} queueName string processingQueueName string + maxConcurrency int + semaphore chan struct{} } type UpdateQueueWorker struct { @@ -49,6 +138,7 @@ type AddressAppendQueueWorker struct { } func NewUpdateQueueWorker(redisQueue *RedisQueue, keyManager *common.LazyKeyManager) *UpdateQueueWorker { + maxConcurrency := getMaxConcurrency() return &UpdateQueueWorker{ BaseQueueWorker: &BaseQueueWorker{ queue: redisQueue, @@ -56,11 +146,14 @@ func NewUpdateQueueWorker(redisQueue *RedisQueue, keyManager *common.LazyKeyMana stopChan: make(chan struct{}), queueName: "zk_update_queue", processingQueueName: "zk_update_processing_queue", + maxConcurrency: maxConcurrency, + semaphore: make(chan 
struct{}, maxConcurrency), }, } } func NewAppendQueueWorker(redisQueue *RedisQueue, keyManager *common.LazyKeyManager) *AppendQueueWorker { + maxConcurrency := getMaxConcurrency() return &AppendQueueWorker{ BaseQueueWorker: &BaseQueueWorker{ queue: redisQueue, @@ -68,11 +161,14 @@ func NewAppendQueueWorker(redisQueue *RedisQueue, keyManager *common.LazyKeyMana stopChan: make(chan struct{}), queueName: "zk_append_queue", processingQueueName: "zk_append_processing_queue", + maxConcurrency: maxConcurrency, + semaphore: make(chan struct{}, maxConcurrency), }, } } func NewAddressAppendQueueWorker(redisQueue *RedisQueue, keyManager *common.LazyKeyManager) *AddressAppendQueueWorker { + maxConcurrency := getMaxConcurrency() return &AddressAppendQueueWorker{ BaseQueueWorker: &BaseQueueWorker{ queue: redisQueue, @@ -80,12 +176,17 @@ func NewAddressAppendQueueWorker(redisQueue *RedisQueue, keyManager *common.Lazy stopChan: make(chan struct{}), queueName: "zk_address_append_queue", processingQueueName: "zk_address_append_processing_queue", + maxConcurrency: maxConcurrency, + semaphore: make(chan struct{}, maxConcurrency), }, } } func (w *BaseQueueWorker) Start() { - logging.Logger().Info().Str("queue", w.queueName).Msg("Starting queue worker") + logging.Logger().Info(). + Str("queue", w.queueName). + Int("max_concurrency", w.maxConcurrency). 
+ Msg("Starting queue worker with parallel processing") for { select { @@ -133,17 +234,19 @@ func (w *BaseQueueWorker) processJobs() { // Add to failed queue with expiration reason expirationErr := fmt.Errorf("job expired after %v (max: %v)", jobAge, JobExpirationTimeout) - w.addToFailedQueue(job, expirationErr) + expiredInputHash := ComputeInputHash(job.Payload) + w.addToFailedQueue(job, expiredInputHash, expirationErr) return } queueWaitTime := jobAge.Seconds() circuitType := "unknown" - if w.queueName == "zk_update_queue" { + switch w.queueName { + case "zk_update_queue": circuitType = "update" - } else if w.queueName == "zk_append_queue" { + case "zk_append_queue": circuitType = "append" - } else if w.queueName == "zk_address_append_queue" { + case "zk_address_append_queue": circuitType = "address-append" } QueueWaitTime.WithLabelValues(circuitType).Observe(queueWaitTime) @@ -153,31 +256,215 @@ func (w *BaseQueueWorker) processJobs() { Str("job_id", job.ID). Str("job_type", job.Type). Str("queue", w.queueName). - Msg("Processing proof job") + Msg("Dequeued proof job") - processingJob := &ProofJob{ - ID: job.ID + "_processing", - Type: "processing", - Payload: job.Payload, - CreatedAt: time.Now(), - } - err = w.queue.EnqueueProof(w.processingQueueName, processingJob) + // Check for duplicate inputs before processing + inputHash := ComputeInputHash(job.Payload) + + // Check if we already have a successful result for this input + cachedProof, cachedJobID, err := w.queue.FindCachedResult(inputHash) if err != nil { + logging.Logger().Warn(). + Err(err). + Str("job_id", job.ID). + Str("input_hash", inputHash). + Msg("Error searching for cached result, continuing with processing") + } else if cachedProof != nil { + // Found a cached successful result, return it immediately + logging.Logger().Info(). + Str("job_id", job.ID). + Str("cached_job_id", cachedJobID). + Str("input_hash", inputHash). 
+ Msg("Returning cached successful proof result without re-processing") + + // Store result for new job ID + resultData, _ := json.Marshal(cachedProof) + resultJob := &ProofJob{ + ID: job.ID, + Type: "result", + Payload: json.RawMessage(resultData), + CreatedAt: time.Now(), + } + err = w.queue.EnqueueProof("zk_results_queue", resultJob) + if err != nil { + logging.Logger().Error().Err(err).Str("job_id", job.ID).Msg("Failed to enqueue cached result") + } + w.queue.StoreResult(job.ID, cachedProof) + w.queue.StoreInputHash(job.ID, inputHash) + w.queue.IndexResultByHash(inputHash, job.ID) return } - err = w.processProofJob(job) - w.removeFromProcessingQueue(job.ID) - + cachedFailure, cachedFailedJobID, err := w.queue.FindCachedFailure(inputHash) if err != nil { - logging.Logger().Error(). + logging.Logger().Warn(). Err(err). Str("job_id", job.ID). - Str("queue", w.queueName). - Msg("Failed to process proof job") + Str("input_hash", inputHash). + Msg("Error searching for cached failure, continuing with processing") + } else if cachedFailure != nil { + // Found a cached failure, return it immediately + logging.Logger().Info(). + Str("job_id", job.ID). + Str("cached_job_id", cachedFailedJobID). + Str("input_hash", inputHash). 
+ Msg("Returning cached failure without re-processing") + + // Extract error message from cached failure + var errorMsg string + if errMsg, ok := cachedFailure["error"].(string); ok { + errorMsg = errMsg + } else { + errorMsg = "Proof generation failed (cached failure)" + } - w.addToFailedQueue(job, err) + // Add to failed queue with new job ID (without full payload to save memory) + failedJob := map[string]interface{}{ + "original_job": map[string]interface{}{ + "id": job.ID, + "type": job.Type, + "payload_size": len(job.Payload), + "created_at": job.CreatedAt, + }, + "error": errorMsg, + "failed_at": time.Now(), + "cached_from": cachedFailedJobID, + } + + failedData, _ := json.Marshal(failedJob) + failedJobStruct := &ProofJob{ + ID: job.ID + "_failed", + Type: "failed", + Payload: json.RawMessage(failedData), + CreatedAt: time.Now(), + } + + err = w.queue.EnqueueProof("zk_failed_queue", failedJobStruct) + if err != nil { + logging.Logger().Error().Err(err).Str("job_id", job.ID).Msg("Failed to enqueue cached failure") + } + w.queue.StoreInputHash(job.ID, inputHash) + w.queue.IndexFailureByHash(inputHash, job.ID) + return } + + // No cached result found, proceed with normal processing + // Store the input hash for this job to enable future deduplication + w.queue.StoreInputHash(job.ID, inputHash) + + w.semaphore <- struct{}{} + + go func(job *ProofJob, inputHash string) { + defer func() { + <-w.semaphore + }() + + proofStartTime := time.Now() + + logging.Logger().Info(). + Str("job_id", job.ID). + Str("queue", w.queueName). + Msg("Starting proof generation") + + processingJob := &ProofJob{ + ID: job.ID + "_processing", + Type: "processing", + Payload: job.Payload, + CreatedAt: time.Now(), + } + err := w.queue.EnqueueProof(w.processingQueueName, processingJob) + if err != nil { + logging.Logger().Error(). + Err(err). + Str("job_id", job.ID). + Str("processing_queue", w.processingQueueName). 
+ Msg("Failed to add job to processing queue") + return + } + + proof, err := w.generateProof(job) + w.removeFromProcessingQueue(job.ID) + + proofDuration := time.Since(proofStartTime) + + if err != nil { + logging.Logger().Error(). + Err(err). + Str("job_id", job.ID). + Str("queue", w.queueName). + Dur("duration", proofDuration). + Msg("Failed to process proof job") + + w.addToFailedQueue(job, inputHash, err) + + // On failure: clean up in-flight marker to allow retry with new job + if delErr := w.queue.DeleteInFlightJob(inputHash, job.ID); delErr != nil { + logging.Logger().Warn(). + Err(delErr). + Str("job_id", job.ID). + Str("input_hash", inputHash). + Msg("Failed to delete in-flight job marker (non-critical)") + } + // Clean up job metadata + if delErr := w.queue.DeleteJobMeta(job.ID); delErr != nil { + logging.Logger().Warn(). + Err(delErr). + Str("job_id", job.ID). + Msg("Failed to delete job metadata (non-critical)") + } + } else { + // Store result with timing information + proofWithTiming := &common.ProofWithTiming{ + Proof: proof, + ProofDurationMs: proofDuration.Milliseconds(), + } + + resultData, _ := json.Marshal(proofWithTiming) + resultJob := &ProofJob{ + ID: job.ID, + Type: "result", + Payload: json.RawMessage(resultData), + CreatedAt: time.Now(), + } + if enqueueErr := w.queue.EnqueueProof("zk_results_queue", resultJob); enqueueErr != nil { + logging.Logger().Error(). + Err(enqueueErr). + Str("job_id", job.ID). + Msg("Failed to enqueue result") + } + if storeErr := w.queue.StoreResult(job.ID, proofWithTiming); storeErr != nil { + logging.Logger().Error(). + Err(storeErr). + Str("job_id", job.ID). + Msg("Failed to store result") + } + + if indexErr := w.queue.IndexResultByHash(inputHash, job.ID); indexErr != nil { + logging.Logger().Warn(). + Err(indexErr). + Str("job_id", job.ID). + Msg("Failed to index result (non-critical)") + } + + logging.Logger().Info(). + Str("job_id", job.ID). + Str("queue", w.queueName). + Dur("duration", proofDuration). 
+ Int64("duration_ms", proofDuration.Milliseconds()). + Msg("Proof job completed successfully") + + // On success: DON'T delete in-flight marker - let it expire with the result. + // This allows future requests with identical inputs to get the cached result + // instead of creating a new job. Both marker and result have 10-min TTL. + // Only clean up job metadata (no longer needed since result is stored). + if delErr := w.queue.DeleteJobMeta(job.ID); delErr != nil { + logging.Logger().Warn(). + Err(delErr). + Str("job_id", job.ID). + Msg("Failed to delete job metadata (non-critical)") + } + } + }(job, inputHash) } func (w *UpdateQueueWorker) Start() { @@ -204,10 +491,12 @@ func (w *AddressAppendQueueWorker) Stop() { w.BaseQueueWorker.Stop() } -func (w *BaseQueueWorker) processProofJob(job *ProofJob) error { +// generateProof generates a proof for the given job and returns it. +// Result storage is handled by the caller to include timing information. +func (w *BaseQueueWorker) generateProof(job *ProofJob) (*common.Proof, error) { proofRequestMeta, err := common.ParseProofRequestMeta(job.Payload) if err != nil { - return fmt.Errorf("failed to parse proof request: %w", err) + return nil, fmt.Errorf("failed to parse proof request: %w", err) } timer := StartProofTimer(string(proofRequestMeta.CircuitType)) @@ -232,13 +521,13 @@ func (w *BaseQueueWorker) processProofJob(job *ProofJob) error { case common.BatchAddressAppendCircuitType: proof, proofError = w.processBatchAddressAppendProof(job.Payload) default: - return fmt.Errorf("unknown circuit type: %s", proofRequestMeta.CircuitType) + return nil, fmt.Errorf("unknown circuit type: %s", proofRequestMeta.CircuitType) } if proofError != nil { timer.ObserveError("proof_generation_failed") RecordJobComplete(false) - return proofError + return nil, proofError } timer.ObserveDuration() @@ -249,18 +538,7 @@ func (w *BaseQueueWorker) processProofJob(job *ProofJob) error { RecordProofSize(string(proofRequestMeta.CircuitType), 
len(proofBytes)) } - resultData, _ := json.Marshal(proof) - resultJob := &ProofJob{ - ID: job.ID, - Type: "result", - Payload: json.RawMessage(resultData), - CreatedAt: time.Now(), - } - err = w.queue.EnqueueProof("zk_results_queue", resultJob) - if err != nil { - return err - } - return w.queue.StoreResult(job.ID, proof) + return proof, nil } func (w *BaseQueueWorker) processInclusionProof(payload json.RawMessage, meta common.ProofRequestMeta) (*common.Proof, error) { @@ -275,13 +553,14 @@ func (w *BaseQueueWorker) processInclusionProof(payload json.RawMessage, meta co return nil, fmt.Errorf("inclusion proof: %w", err) } - if meta.Version == 1 { + switch meta.Version { + case 1: var params v1.InclusionParameters if err := json.Unmarshal(payload, ¶ms); err != nil { return nil, fmt.Errorf("failed to unmarshal legacy inclusion parameters: %w", err) } return v1.ProveInclusion(ps, ¶ms) - } else if meta.Version == 2 { + case 2: var params v2.InclusionParameters if err := json.Unmarshal(payload, ¶ms); err != nil { return nil, fmt.Errorf("failed to unmarshal inclusion parameters: %w", err) @@ -294,7 +573,7 @@ func (w *BaseQueueWorker) processInclusionProof(payload json.RawMessage, meta co func (w *BaseQueueWorker) processNonInclusionProof(payload json.RawMessage, meta common.ProofRequestMeta) (*common.Proof, error) { ps, err := w.keyManager.GetMerkleSystem( - 0, + 0, 0, meta.AddressTreeHeight, meta.NumAddresses, @@ -333,13 +612,14 @@ func (w *BaseQueueWorker) processCombinedProof(payload json.RawMessage, meta com return nil, fmt.Errorf("combined proof: %w", err) } - if meta.AddressTreeHeight == 26 { + switch meta.AddressTreeHeight { + case 26: var params v1.CombinedParameters if err := json.Unmarshal(payload, ¶ms); err != nil { return nil, fmt.Errorf("failed to unmarshal legacy combined parameters: %w", err) } return v1.ProveCombined(ps, ¶ms) - } else if meta.AddressTreeHeight == 40 { + case 40: var params v2.CombinedParameters if err := json.Unmarshal(payload, ¶ms); err 
!= nil { return nil, fmt.Errorf("failed to unmarshal combined parameters: %w", err) @@ -408,7 +688,7 @@ func (w *BaseQueueWorker) processBatchAddressAppendProof(payload json.RawMessage func (w *BaseQueueWorker) removeFromProcessingQueue(jobID string) { processingQueueLength, _ := w.queue.Client.LLen(w.queue.Ctx, w.processingQueueName).Result() - for i := int64(0); i < processingQueueLength; i++ { + for i := range processingQueueLength { item, err := w.queue.Client.LIndex(w.queue.Ctx, w.processingQueueName, i).Result() if err != nil { continue @@ -422,11 +702,27 @@ func (w *BaseQueueWorker) removeFromProcessingQueue(jobID string) { } } -func (w *BaseQueueWorker) addToFailedQueue(job *ProofJob, err error) { +func (w *BaseQueueWorker) addToFailedQueue(job *ProofJob, inputHash string, err error) { + // Extract circuit type from payload for debugging, but don't store full payload + // to prevent memory issues (payloads can be hundreds of KB) + var circuitType string + var payloadMeta map[string]interface{} + if json.Unmarshal(job.Payload, &payloadMeta) == nil { + if ct, ok := payloadMeta["circuitType"].(string); ok { + circuitType = ct + } + } + failedJob := map[string]interface{}{ - "original_job": job, - "error": err.Error(), - "failed_at": time.Now(), + "original_job": map[string]interface{}{ + "id": job.ID, + "type": job.Type, + "circuit_type": circuitType, + "payload_size": len(job.Payload), + "created_at": job.CreatedAt, + }, + "error": err.Error(), + "failed_at": time.Now(), } failedData, _ := json.Marshal(failedJob) @@ -437,8 +733,21 @@ func (w *BaseQueueWorker) addToFailedQueue(job *ProofJob, err error) { CreatedAt: time.Now(), } - err = w.queue.EnqueueProof("zk_failed_queue", failedJobStruct) - if err != nil { - return + enqueueErr := w.queue.EnqueueProof("zk_failed_queue", failedJobStruct) + if enqueueErr != nil { + logging.Logger().Error(). + Err(enqueueErr). + Str("job_id", job.ID). 
+ Msg("Failed to add job to failed queue") + } + + // Index the failure for O(1) cached lookups + if inputHash != "" { + if indexErr := w.queue.IndexFailureByHash(inputHash, job.ID); indexErr != nil { + logging.Logger().Warn(). + Err(indexErr). + Str("job_id", job.ID). + Msg("Failed to index failure (non-critical)") + } } } diff --git a/prover/server/server/server.go b/prover/server/server/server.go index 23b089c815..a332e30c13 100644 --- a/prover/server/server/server.go +++ b/prover/server/server/server.go @@ -11,6 +11,7 @@ import ( v1 "light/light-prover/prover/v1" v2 "light/light-prover/prover/v2" "net/http" + "strings" "time" "github.com/google/uuid" @@ -73,9 +74,12 @@ func (handler proofStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - if err != nil { - return + if err := json.NewEncoder(w).Encode(response); err != nil { + logging.Logger().Error(). + Err(err). + Str("job_id", jobID). + Str("response_type", "completed_result"). + Msg("Failed to encode JSON response") } return } @@ -83,9 +87,69 @@ func (handler proofStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque jobExists, jobStatus, jobInfo := handler.checkJobExistsDetailed(jobID) if !jobExists { + // Fallback: check job metadata - this catches jobs that were submitted but not yet + // visible in queues due to Redis replica lag or race conditions + jobMeta, metaErr := handler.redisQueue.GetJobMeta(jobID) + if metaErr != nil { + logging.Logger().Warn(). + Err(metaErr). + Str("job_id", jobID). + Msg("Error checking job metadata") + } + + if jobMeta != nil { + // Job was submitted (we have metadata) but not found in queues - return queued status + logging.Logger().Info(). + Str("job_id", jobID). + Interface("job_meta", jobMeta). 
+ Msg("Job not found in queues but metadata exists - returning queued status") + + response := map[string]interface{}{ + "job_id": jobID, + "status": "queued", + "message": "Job is queued and waiting to be processed", + } + if circuitType, ok := jobMeta["circuit_type"]; ok { + response["circuit_type"] = circuitType + } + if submittedAt, ok := jobMeta["submitted_at"]; ok { + response["submitted_at"] = submittedAt + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + if err := json.NewEncoder(w).Encode(response); err != nil { + logging.Logger().Error(). + Err(err). + Str("job_id", jobID). + Str("response_type", "queued_status"). + Msg("Failed to encode JSON response") + } + return + } + logging.Logger().Warn(). Str("job_id", jobID). - Msg("Job not found in any queue") + Msg("Job not found in any queue or metadata") + + if handler.redisQueue != nil && handler.redisQueue.Client != nil { + if stats, statsErr := handler.redisQueue.GetQueueStats(); statsErr == nil { + logging.Logger().Info(). + Str("job_id", jobID). + Interface("queue_stats", stats). + Str("redis_addr", handler.redisQueue.Client.Options().Addr). + Msg("Queue stats at job_not_found") + } else { + logging.Logger().Warn(). + Err(statsErr). + Str("job_id", jobID). + Msg("Failed to fetch queue stats during job_not_found") + } + } + + // Clean up any stale in-flight marker for this job ID + // This allows new requests with the same input to create fresh jobs + handler.redisQueue.CleanupStaleInFlightMarker(jobID) notFoundError := &Error{ StatusCode: http.StatusNotFound, @@ -96,17 +160,35 @@ func (handler proofStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque return } - logging.Logger().Info(). + // Log job status without payload to avoid filling up log buffer + logEvent := logging.Logger().Info(). Str("job_id", jobID). - Str("status", jobStatus). - Interface("job_info", jobInfo). 
- Msg("Job found but not completed") + Str("status", jobStatus) + if jobInfo != nil { + if ct, ok := jobInfo["circuit_type"]; ok { + logEvent = logEvent.Interface("circuit_type", ct) + } + if ca, ok := jobInfo["created_at"]; ok { + logEvent = logEvent.Interface("created_at", ca) + } + } + logEvent.Msg("Job found but not completed") response := map[string]interface{}{ "job_id": jobID, "status": jobStatus, } + // Handle completed jobs - include result if available + if jobStatus == "completed" && jobInfo != nil { + if result, ok := jobInfo["result"]; ok { + response["result"] = result + logging.Logger().Info(). + Str("job_id", jobID). + Msg("Returning result from checkJobExistsDetailed") + } + } + // Handle failed jobs specially - extract actual error details if jobStatus == "failed" && jobInfo != nil { if payloadRaw, ok := jobInfo["payload"]; ok { @@ -149,7 +231,18 @@ func (handler proofStatusHandler) ServeHTTP(w http.ResponseWriter, r *http.Reque } w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusAccepted) + + // Return 200 OK if job is completed with result, otherwise 202 Accepted + if jobStatus == "completed" { + if _, hasResult := response["result"]; hasResult { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusAccepted) + } + } else { + w.WriteHeader(http.StatusAccepted) + } + err = json.NewEncoder(w).Encode(response) if err != nil { return @@ -177,35 +270,89 @@ func getStatusMessage(status string) string { } func (handler proofStatusHandler) checkJobExistsDetailed(jobID string) (bool, string, map[string]interface{}) { - if job, found := handler.findJobInQueue("zk_update_queue", jobID); found { - return true, "queued", job - } + // First check result cache (fast O(1) lookup) + result, err := handler.redisQueue.GetResult(jobID) + if err == nil && result != nil { + logging.Logger().Debug(). + Str("job_id", jobID). 
+ Msg("Job found in result cache") - if job, found := handler.findJobInQueue("zk_append_queue", jobID); found { - return true, "queued", job + jobInfo := map[string]interface{}{ + "result": result, + "result_cached": true, + } + return true, "completed", jobInfo } - if job, found := handler.findJobInQueue("zk_address_append_queue", jobID); found { - return true, "queued", job + // Check job metadata to determine which queue to search (avoids full scan of all queues) + jobMeta, metaErr := handler.redisQueue.GetJobMeta(jobID) + if metaErr == nil && jobMeta != nil { + // We have metadata - check only the relevant queues based on queue name + if queueName, ok := jobMeta["queue"].(string); ok { + // Check main queue first + if job, found := handler.findJobInQueue(queueName, jobID); found { + return true, "queued", job + } + // Check processing queue for this circuit type + // Validate queueName before slicing to avoid panic + if len(queueName) >= 6 && strings.HasSuffix(queueName, "_queue") { + base := queueName[:len(queueName)-6] + processingQueue := base + "_processing_queue" + if job, found := handler.findJobInQueue(processingQueue, jobID); found { + return true, "processing", job + } + } else { + logging.Logger().Warn(). + Str("job_id", jobID). + Str("queue_name", queueName). 
+ Msg("Malformed queue name in job metadata - skipping processing queue check") + } + } + // Job has metadata but not found in expected queues - may be in results or failed + if job, found := handler.findJobInQueue("zk_failed_queue", jobID); found { + return true, "failed", job + } + // Check results queue + if job, found := handler.findJobInQueue("zk_results_queue", jobID); found { + handler.extractResultFromPayload(job) + return true, "completed", job + } + // Return metadata-based status even if not found in queues + // This handles race conditions where job moved between queues + status := "queued" + if metaStatus, ok := jobMeta["status"].(string); ok { + status = metaStatus + } + return true, status, map[string]interface{}{ + "circuit_type": jobMeta["circuit_type"], + "submitted_at": jobMeta["submitted_at"], + "from_meta": true, + } } - if job, found := handler.findJobInQueue("zk_update_processing_queue", jobID); found { - return true, "processing", job + // No metadata - fall back to checking failed and results queues only + // (These are the terminal states where jobs might exist without metadata) + if job, found := handler.findJobInQueue("zk_failed_queue", jobID); found { + return true, "failed", job } - if job, found := handler.findJobInQueue("zk_append_processing_queue", jobID); found { - return true, "processing", job + if job, found := handler.findJobInQueue("zk_results_queue", jobID); found { + handler.extractResultFromPayload(job) + return true, "completed", job } - if job, found := handler.findJobInQueue("zk_address_append_processing_queue", jobID); found { - return true, "processing", job - } + return false, "", nil +} - if job, found := handler.findJobInQueue("zk_failed_queue", jobID); found { - return true, "failed", job +func (handler proofStatusHandler) extractResultFromPayload(job map[string]interface{}) { + if payloadRaw, ok := job["payload"]; ok { + if payloadStr, ok := payloadRaw.(string); ok { + var payloadData map[string]interface{} + if 
json.Unmarshal([]byte(payloadStr), &payloadData) == nil { + job["result"] = payloadData + } + } } - - return false, "", nil } func (handler proofStatusHandler) findJobInQueue(queueName, jobID string) (map[string]interface{}, bool) { @@ -513,23 +660,95 @@ func RunEnhanced(config *EnhancedConfig, redisQueue *RedisQueue, keyManager *com return } - jobID := uuid.New().String() + queueName := GetQueueNameForCircuit(proofRequestMeta.CircuitType) + + // Compute input hash for deduplication + inputHash := ComputeInputHash(json.RawMessage(buf)) + + // Check for existing in-flight job with same input + dedupResult, err := redisQueue.DeduplicateJob(inputHash) + if err != nil { + logging.Logger().Error(). + Err(err). + Str("input_hash", inputHash). + Msg("Failed to deduplicate job") + http.Error(w, "Failed to register job", http.StatusInternalServerError) + return + } + + // If deduplicated to an existing job, return early + if dedupResult.IsDeduplicated { + response := map[string]interface{}{ + "job_id": dedupResult.JobID, + "status": "already_queued", + "queue": queueName, + "circuit_type": string(proofRequestMeta.CircuitType), + "message": "Proof request with identical input already in queue. Returning existing job ID.", + "deduplicated": true, + } + + logging.Logger().Info(). + Str("existing_job_id", dedupResult.JobID). + Str("input_hash", inputHash). + Str("circuit_type", string(proofRequestMeta.CircuitType)). + Msg("Deduplicated proof request via /queue/add") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + if err := json.NewEncoder(w).Encode(response); err != nil { + logging.Logger().Error(). + Err(err). + Str("job_id", dedupResult.JobID). + Str("response_type", "deduplicated_queue_add_response"). 
+ Msg("Failed to encode JSON response") + } + return + } + + // This is a new job + jobID := dedupResult.JobID job := &ProofJob{ - ID: jobID, - Type: "zk_proof", - Payload: json.RawMessage(buf), - CreatedAt: time.Now(), + ID: jobID, + Type: "zk_proof", + Payload: json.RawMessage(buf), + CreatedAt: time.Now(), + TreeID: proofRequestMeta.TreeID, + BatchIndex: proofRequestMeta.BatchIndex, } - queueName := GetQueueNameForCircuit(proofRequestMeta.CircuitType) + // Store job metadata BEFORE enqueueing to prevent race condition where worker + // picks up job before metadata exists, causing job_not_found on status checks + if err := redisQueue.StoreJobMeta(jobID, queueName, string(proofRequestMeta.CircuitType)); err != nil { + logging.Logger().Warn(). + Err(err). + Str("job_id", jobID). + Str("queue", queueName). + Msg("Failed to store job metadata (will still attempt to enqueue)") + } + + // Store input hash mapping for cleanup when job completes + redisQueue.StoreInputHash(jobID, inputHash) err = redisQueue.EnqueueProof(queueName, job) if err != nil { + // Clean up in-flight marker and metadata since we failed to enqueue + if delErr := redisQueue.DeleteInFlightJob(inputHash, jobID); delErr != nil { + logging.Logger().Error().Err(delErr).Str("job_id", jobID).Msg("Failed to cleanup in-flight marker after enqueue failure - may cause stale deduplication") + } + if delErr := redisQueue.DeleteJobMeta(jobID); delErr != nil { + logging.Logger().Error().Err(delErr).Str("job_id", jobID).Msg("Failed to cleanup job metadata after enqueue failure") + } unexpectedError(err).send(w) return } + logging.Logger().Info(). + Str("job_id", jobID). + Str("queue", queueName). + Str("circuit_type", string(proofRequestMeta.CircuitType)). 
+ Msg("Enqueued proof job") + response := map[string]interface{}{ "job_id": jobID, "status": "queued", @@ -647,24 +866,91 @@ type healthHandler struct { } func (handler proveHandler) handleAsyncProof(w http.ResponseWriter, r *http.Request, buf []byte, meta common.ProofRequestMeta) { - jobID := uuid.New().String() - ProofRequestsTotal.WithLabelValues(string(meta.CircuitType)).Inc() RecordCircuitInputSize(string(meta.CircuitType), len(buf)) + queueName := GetQueueNameForCircuit(meta.CircuitType) + + // Compute input hash for deduplication + inputHash := ComputeInputHash(json.RawMessage(buf)) + + // Check for existing in-flight job with same input + dedupResult, err := handler.redisQueue.DeduplicateJob(inputHash) + if err != nil { + logging.Logger().Error(). + Err(err). + Str("input_hash", inputHash). + Msg("Failed to deduplicate job") + http.Error(w, "Failed to register job", http.StatusInternalServerError) + return + } + + // If deduplicated to an existing job, return early + if dedupResult.IsDeduplicated { + response := map[string]interface{}{ + "job_id": dedupResult.JobID, + "status": "already_queued", + "circuit_type": string(meta.CircuitType), + "queue": queueName, + "message": "Proof request with identical input already in queue. Returning existing job ID.", + "deduplicated": true, + } + + logging.Logger().Info(). + Str("existing_job_id", dedupResult.JobID). + Str("input_hash", inputHash). + Str("circuit_type", string(meta.CircuitType)). + Msg("Deduplicated proof request - returning existing job") + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + if err := json.NewEncoder(w).Encode(response); err != nil { + logging.Logger().Error(). + Err(err). + Str("job_id", dedupResult.JobID). + Str("response_type", "deduplicated_async_response"). 
+ Msg("Failed to encode JSON response") + } + return + } + + // This is a new job + jobID := dedupResult.JobID + job := &ProofJob{ - ID: jobID, - Type: "zk_proof", - Payload: json.RawMessage(buf), - CreatedAt: time.Now(), + ID: jobID, + Type: "zk_proof", + Payload: json.RawMessage(buf), + CreatedAt: time.Now(), + TreeID: meta.TreeID, + BatchIndex: meta.BatchIndex, } - queueName := GetQueueNameForCircuit(meta.CircuitType) + // Store job metadata BEFORE enqueueing to prevent race condition where worker + // picks up job before metadata exists, causing job_not_found on status checks + if err := handler.redisQueue.StoreJobMeta(jobID, queueName, string(meta.CircuitType)); err != nil { + logging.Logger().Warn(). + Err(err). + Str("job_id", jobID). + Str("queue", queueName). + Msg("Failed to store job metadata (will still attempt to enqueue)") + } - err := handler.redisQueue.EnqueueProof(queueName, job) + // Store input hash mapping for cleanup when job completes + handler.redisQueue.StoreInputHash(jobID, inputHash) + + err = handler.redisQueue.EnqueueProof(queueName, job) if err != nil { logging.Logger().Error().Err(err).Msg("Failed to enqueue proof job") + // Clean up in-flight marker and metadata since we failed to enqueue + if delErr := handler.redisQueue.DeleteInFlightJob(inputHash, jobID); delErr != nil { + logging.Logger().Error().Err(delErr).Str("job_id", jobID).Msg("Failed to cleanup in-flight marker after enqueue failure - may cause stale deduplication") + } + if delErr := handler.redisQueue.DeleteJobMeta(jobID); delErr != nil { + logging.Logger().Error().Err(delErr).Str("job_id", jobID).Msg("Failed to cleanup job metadata after enqueue failure") + } + if handler.isBatchOperation(meta.CircuitType) { serviceUnavailableError := &Error{ StatusCode: http.StatusServiceUnavailable, @@ -735,6 +1021,21 @@ func (handler proveHandler) handleSyncProof(w http.ResponseWriter, r *http.Reque resultChan := make(chan proofResult, 1) go func() { + // Recover from panics to 
prevent server crash from malformed input + defer func() { + if r := recover(); r != nil { + ProofPanicsTotal.WithLabelValues(string(meta.CircuitType)).Inc() + logging.Logger().Error(). + Interface("panic", r). + Str("circuit_type", string(meta.CircuitType)). + Msg("Panic recovered in proof processing") + resultChan <- proofResult{ + proof: nil, + err: unexpectedError(fmt.Errorf("internal error during proof processing: %v", r)), + } + } + }() + timer := StartProofTimer(string(meta.CircuitType)) RecordCircuitInputSize(string(meta.CircuitType), len(buf)) @@ -964,7 +1265,8 @@ func (handler proveHandler) inclusionProof(buf []byte, proofRequestMeta common.P return nil, provingError(fmt.Errorf("inclusion proof: %w", err)) } - if proofRequestMeta.Version == 1 { + switch proofRequestMeta.Version { + case 1: var params v1.InclusionParameters if err := json.Unmarshal(buf, ¶ms); err != nil { @@ -975,7 +1277,7 @@ func (handler proveHandler) inclusionProof(buf []byte, proofRequestMeta common.P return nil, provingError(err) } return proof, nil - } else if proofRequestMeta.Version == 2 { + case 2: var params v2.InclusionParameters if err := json.Unmarshal(buf, ¶ms); err != nil { return nil, malformedBodyError(err) @@ -1007,7 +1309,8 @@ func (handler proveHandler) nonInclusionProof(buf []byte, proofRequestMeta commo return nil, provingError(fmt.Errorf("non-inclusion proof: %w", err)) } - if proofRequestMeta.AddressTreeHeight == 26 { + switch proofRequestMeta.AddressTreeHeight { + case 26: var params v1.NonInclusionParameters var err = json.Unmarshal(buf, ¶ms) @@ -1022,7 +1325,7 @@ func (handler proveHandler) nonInclusionProof(buf []byte, proofRequestMeta commo return nil, provingError(err) } return proof, nil - } else if proofRequestMeta.AddressTreeHeight == 40 { + case 40: var params v2.NonInclusionParameters var err = json.Unmarshal(buf, ¶ms) @@ -1037,7 +1340,7 @@ func (handler proveHandler) nonInclusionProof(buf []byte, proofRequestMeta commo return nil, provingError(err) } 
return proof, nil - } else { + default: return nil, provingError(fmt.Errorf("no proving system for %+v proofRequest", proofRequestMeta)) } } @@ -1059,7 +1362,8 @@ func (handler proveHandler) combinedProof(buf []byte, proofRequestMeta common.Pr return nil, provingError(fmt.Errorf("combined proof: %w", err)) } - if proofRequestMeta.AddressTreeHeight == 26 { + switch proofRequestMeta.AddressTreeHeight { + case 26: var params v1.CombinedParameters if err := json.Unmarshal(buf, ¶ms); err != nil { return nil, malformedBodyError(err) @@ -1069,7 +1373,7 @@ func (handler proveHandler) combinedProof(buf []byte, proofRequestMeta common.Pr return nil, provingError(err) } return proof, nil - } else if proofRequestMeta.AddressTreeHeight == 40 { + case 40: var params v2.CombinedParameters if err := json.Unmarshal(buf, ¶ms); err != nil { return nil, malformedBodyError(err) @@ -1079,7 +1383,7 @@ func (handler proveHandler) combinedProof(buf []byte, proofRequestMeta common.Pr return nil, provingError(err) } return proof, nil - } else { + default: return nil, provingError(fmt.Errorf("no proving system for %+v proofRequest", proofRequestMeta)) } } @@ -1091,6 +1395,11 @@ func (handler healthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } logging.Logger().Info().Msg("received health check request") responseBytes, err := json.Marshal(map[string]string{"status": "ok"}) + if err != nil { + logging.Logger().Error().Err(err).Msg("error marshaling response") + w.WriteHeader(http.StatusInternalServerError) + return + } w.WriteHeader(http.StatusOK) _, err = w.Write(responseBytes) if err != nil { diff --git a/scripts/devenv.sh b/scripts/devenv.sh index 3584f8652c..605bce12b7 100755 --- a/scripts/devenv.sh +++ b/scripts/devenv.sh @@ -67,7 +67,7 @@ PATH="${LIGHT_PROTOCOL_TOPLEVEL}/.local/go/bin:${PATH}" PATH="${LIGHT_PROTOCOL_TOPLEVEL}/.local/npm-global/bin:${PATH}" # Remove the original Rust-related PATH entries -PATH=$(echo "$PATH" | tr ':' '\n' | grep -vE "/.rustup/|/.cargo/" | 
tr '\n' ':' | sed 's/:$//') +PATH=$(echo "$PATH" | tr ':' '\n' | grep -vE "/.rustup/|/.cargo/|/.local/share/solana/" | tr '\n' ':' | sed 's/:$//') # Define alias of `light` to use the CLI built from source (only if not in CI) if [ -z "${CI:-}" ]; then diff --git a/scripts/devenv/versions.sh b/scripts/devenv/versions.sh index 14a1c310f2..394405759e 100755 --- a/scripts/devenv/versions.sh +++ b/scripts/devenv/versions.sh @@ -13,7 +13,7 @@ export SOLANA_VERSION="2.2.15" export ANCHOR_VERSION="0.31.1" export JQ_VERSION="1.8.0" export PHOTON_VERSION="0.51.2" -export PHOTON_COMMIT="711c47b20330c6bb78feb0a2c15e8292fcd0a7b0" +export PHOTON_COMMIT="3dbfb8e6772779fc89c640b5b0823b95d1958efc" export REDIS_VERSION="8.0.1" export ANCHOR_TAG="anchor-v${ANCHOR_VERSION}" diff --git a/sdk-libs/client/src/indexer/indexer_trait.rs b/sdk-libs/client/src/indexer/indexer_trait.rs index 51066c07b6..b051ab3c1d 100644 --- a/sdk-libs/client/src/indexer/indexer_trait.rs +++ b/sdk-libs/client/src/indexer/indexer_trait.rs @@ -4,13 +4,14 @@ use solana_pubkey::Pubkey; use super::{ response::{Items, ItemsWithCursor, Response}, types::{ - CompressedAccount, CompressedTokenAccount, OwnerBalance, QueueElementsResult, - QueueInfoResult, SignatureWithMetadata, TokenBalance, ValidityProofWithContext, + CompressedAccount, CompressedTokenAccount, OwnerBalance, QueueInfoResult, + SignatureWithMetadata, TokenBalance, ValidityProofWithContext, }, Address, AddressWithTree, GetCompressedAccountsByOwnerConfig, GetCompressedTokenAccountsByOwnerOrDelegateOptions, Hash, IndexerError, IndexerRpcConfig, MerkleProof, NewAddressProofWithContext, PaginatedOptions, QueueElementsV2Options, RetryConfig, }; +use crate::indexer::QueueElementsResult; // TODO: remove all references in input types. 
#[async_trait] pub trait Indexer: std::marker::Send + std::marker::Sync { @@ -171,12 +172,8 @@ pub trait Indexer: std::marker::Send + std::marker::Sync { config: Option, ) -> Result, IndexerError>; - /// Returns queue elements from the queue with the given merkle tree pubkey. - /// Can fetch from output queue (append), input queue (nullify), address queue, or combinations. - /// For input queues account compression program does not store queue elements in the - /// account data but only emits these in the public transaction event. The - /// indexer needs the queue elements to create batch update proofs. /// Returns queue elements with deduplicated nodes for efficient staging tree construction. + /// Supports output queue, input queue, and address queue. async fn get_queue_elements( &mut self, merkle_tree_pubkey: [u8; 32], diff --git a/sdk-libs/client/src/indexer/photon_indexer.rs b/sdk-libs/client/src/indexer/photon_indexer.rs index a75a14b40c..a220c16554 100644 --- a/sdk-libs/client/src/indexer/photon_indexer.rs +++ b/sdk-libs/client/src/indexer/photon_indexer.rs @@ -1575,6 +1575,20 @@ impl Indexer for PhotonIndexer { ..Default::default() }; + tracing::info!( + "get_queue_elements request: output_queue={:?}, input_queue={:?}", + request.params.output_queue.as_ref().map(|q| ( + q.limit, + q.start_index, + q.zkp_batch_size + )), + request.params.input_queue.as_ref().map(|q| ( + q.limit, + q.start_index, + q.zkp_batch_size + )), + ); + let result = photon_api::apis::default_api::get_queue_elements_post( &self.configuration, request, @@ -1719,22 +1733,8 @@ impl Indexer for PhotonIndexer { .map(|h| Hash::from_base58(h)) .collect(); - // Parse low_element_proofs for debugging/validation - let low_element_proofs: Result>, IndexerError> = address - .low_element_proofs - .iter() - .map(|proof| { - proof.iter().map(|h| Hash::from_base58(h)).collect::, - IndexerError, - >>( - ) - }) - .collect(); - Some(super::AddressQueueData { addresses: addresses?, - queue_indices: 
address.queue_indices, low_element_values: low_element_values?, low_element_next_values: low_element_next_values?, low_element_indices: address.low_element_indices, @@ -1746,7 +1746,6 @@ impl Indexer for PhotonIndexer { subtrees: subtrees?, start_index: address.start_index, root_seq: address.root_seq, - low_element_proofs: low_element_proofs?, }) } else { None diff --git a/sdk-libs/client/src/indexer/types.rs b/sdk-libs/client/src/indexer/types.rs index 877e99ac16..d2f8ef89a2 100644 --- a/sdk-libs/client/src/indexer/types.rs +++ b/sdk-libs/client/src/indexer/types.rs @@ -91,7 +91,6 @@ pub struct StateQueueData { #[derive(Debug, Clone, PartialEq, Default)] pub struct AddressQueueData { pub addresses: Vec<[u8; 32]>, - pub queue_indices: Vec, pub low_element_values: Vec<[u8; 32]>, pub low_element_next_values: Vec<[u8; 32]>, pub low_element_indices: Vec, @@ -105,8 +104,6 @@ pub struct AddressQueueData { pub subtrees: Vec<[u8; 32]>, pub start_index: u64, pub root_seq: u64, - /// Original low element proofs from indexer (for debugging/validation) - pub low_element_proofs: Vec>, } impl AddressQueueData { diff --git a/sdk-libs/client/src/rpc/client.rs b/sdk-libs/client/src/rpc/client.rs index 13930cbc9e..8a3226c72e 100644 --- a/sdk-libs/client/src/rpc/client.rs +++ b/sdk-libs/client/src/rpc/client.rs @@ -17,11 +17,12 @@ use solana_commitment_config::CommitmentConfig; use solana_hash::Hash; use solana_instruction::Instruction; use solana_keypair::Keypair; +use solana_message::{v0, AddressLookupTableAccount, VersionedMessage}; use solana_pubkey::{pubkey, Pubkey}; use solana_rpc_client::rpc_client::RpcClient; use solana_rpc_client_api::config::{RpcSendTransactionConfig, RpcTransactionConfig}; use solana_signature::Signature; -use solana_transaction::Transaction; +use solana_transaction::{versioned::VersionedTransaction, Transaction}; use solana_transaction_status_client_types::{ option_serializer::OptionSerializer, TransactionStatus, UiInstruction, UiTransactionEncoding, }; 
@@ -681,6 +682,42 @@ impl Rpc for LightClient { .await } + /// Creates and sends a versioned transaction with address lookup tables. + /// + /// `address_lookup_tables` must contain pre-fetched `AddressLookupTableAccount` values + /// loaded from the chain. Callers are responsible for resolving these accounts before + /// calling this method. Unresolved or missing lookup tables will cause compilation to fail. + /// + /// Returns `RpcError::CustomError` on message compilation failure, + /// `RpcError::SigningError` on signing failure. + async fn create_and_send_versioned_transaction<'a>( + &'a mut self, + instructions: &'a [Instruction], + payer: &'a Pubkey, + signers: &'a [&'a Keypair], + address_lookup_tables: &'a [AddressLookupTableAccount], + ) -> Result { + let blockhash = self.get_latest_blockhash().await?.0; + + let message = + v0::Message::try_compile(payer, instructions, address_lookup_tables, blockhash) + .map_err(|e| { + RpcError::CustomError(format!("Failed to compile v0 message: {}", e)) + })?; + + let versioned_message = VersionedMessage::V0(message); + + let transaction = VersionedTransaction::try_new(versioned_message, signers) + .map_err(|e| RpcError::SigningError(e.to_string()))?; + + self.retry(|| async { + self.client + .send_and_confirm_transaction(&transaction) + .map_err(RpcError::from) + }) + .await + } + fn indexer(&self) -> Result<&impl Indexer, RpcError> { self.indexer.as_ref().ok_or(RpcError::IndexerNotInitialized) } diff --git a/sdk-libs/client/src/rpc/rpc_trait.rs b/sdk-libs/client/src/rpc/rpc_trait.rs index 338059cdef..2ece7386fd 100644 --- a/sdk-libs/client/src/rpc/rpc_trait.rs +++ b/sdk-libs/client/src/rpc/rpc_trait.rs @@ -9,6 +9,7 @@ use solana_commitment_config::CommitmentConfig; use solana_hash::Hash; use solana_instruction::Instruction; use solana_keypair::Keypair; +use solana_message::AddressLookupTableAccount; use solana_pubkey::Pubkey; use solana_rpc_client_api::config::RpcSendTransactionConfig; use 
solana_signature::Signature; @@ -53,7 +54,7 @@ impl LightClientConfig { pub fn local() -> Self { Self { url: RpcUrl::Localnet.to_string(), - commitment_config: Some(CommitmentConfig::confirmed()), + commitment_config: Some(CommitmentConfig::processed()), photon_url: Some("http://127.0.0.1:8784".to_string()), api_key: None, fetch_active_tree: false, @@ -182,6 +183,14 @@ pub trait Rpc: Send + Sync + Debug + 'static { self.process_transaction(transaction).await } + async fn create_and_send_versioned_transaction<'a>( + &'a mut self, + instructions: &'a [Instruction], + payer: &'a Pubkey, + signers: &'a [&'a Keypair], + address_lookup_tables: &'a [AddressLookupTableAccount], + ) -> Result; + async fn create_and_send_transaction_with_public_event( &mut self, instruction: &[Instruction], diff --git a/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_200_response.rs b/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_200_response.rs new file mode 100644 index 0000000000..13441587c3 --- /dev/null +++ b/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_200_response.rs @@ -0,0 +1,60 @@ +/* + * photon-indexer + * + * Solana indexer for general compression + * + * The version of the OpenAPI document: 0.50.0 + * + * Generated by: https://openapi-generator.tech + */ + +use crate::models; + +#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] +pub struct GetQueueElementsV2Post200Response { + #[serde(rename = "error", skip_serializing_if = "Option::is_none")] + pub error: Option>, + /// An ID to identify the response. + #[serde(rename = "id")] + pub id: Id, + /// The version of the JSON-RPC protocol. 
+ #[serde(rename = "jsonrpc")] + pub jsonrpc: Jsonrpc, + #[serde(rename = "result", skip_serializing_if = "Option::is_none")] + pub result: Option>, +} + +impl GetQueueElementsV2Post200Response { + pub fn new(id: Id, jsonrpc: Jsonrpc) -> GetQueueElementsV2Post200Response { + GetQueueElementsV2Post200Response { + error: None, + id, + jsonrpc, + result: None, + } + } +} +/// An ID to identify the response. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +pub enum Id { + #[serde(rename = "test-account")] + TestAccount, +} + +impl Default for Id { + fn default() -> Id { + Self::TestAccount + } +} +/// The version of the JSON-RPC protocol. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +pub enum Jsonrpc { + #[serde(rename = "2.0")] + Variant2Period0, +} + +impl Default for Jsonrpc { + fn default() -> Jsonrpc { + Self::Variant2Period0 + } +} diff --git a/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_200_response_result.rs b/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_200_response_result.rs new file mode 100644 index 0000000000..34b395e7e8 --- /dev/null +++ b/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_200_response_result.rs @@ -0,0 +1,31 @@ +/* + * photon-indexer + * + * Solana indexer for general compression + * + * The version of the OpenAPI document: 0.50.0 + * + * Generated by: https://openapi-generator.tech + */ + +use crate::models; + +#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] +pub struct GetQueueElementsV2Post200ResponseResult { + #[serde(rename = "context")] + pub context: Box, + #[serde(rename = "stateQueue", skip_serializing_if = "Option::is_none")] + pub state_queue: Option>, + #[serde(rename = "addressQueue", skip_serializing_if = "Option::is_none")] + pub address_queue: Option>, +} + +impl GetQueueElementsV2Post200ResponseResult { + pub fn new(context: models::Context) -> 
GetQueueElementsV2Post200ResponseResult { + GetQueueElementsV2Post200ResponseResult { + context: Box::new(context), + state_queue: None, + address_queue: None, + } + } +} diff --git a/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_request.rs b/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_request.rs new file mode 100644 index 0000000000..9adc582be7 --- /dev/null +++ b/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_request.rs @@ -0,0 +1,78 @@ +/* + * photon-indexer + * + * Solana indexer for general compression + * + * The version of the OpenAPI document: 0.50.0 + * + * Generated by: https://openapi-generator.tech + */ + +use crate::models; + +#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] +pub struct GetQueueElementsV2PostRequest { + /// An ID to identify the request. + #[serde(rename = "id")] + pub id: Id, + /// The version of the JSON-RPC protocol. + #[serde(rename = "jsonrpc")] + pub jsonrpc: Jsonrpc, + /// The name of the method to invoke. + #[serde(rename = "method")] + pub method: Method, + #[serde(rename = "params")] + pub params: Box, +} + +impl GetQueueElementsV2PostRequest { + pub fn new( + id: Id, + jsonrpc: Jsonrpc, + method: Method, + params: models::GetQueueElementsV2PostRequestParams, + ) -> GetQueueElementsV2PostRequest { + GetQueueElementsV2PostRequest { + id, + jsonrpc, + method, + params: Box::new(params), + } + } +} +/// An ID to identify the request. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +pub enum Id { + #[serde(rename = "test-account")] + TestAccount, +} + +impl Default for Id { + fn default() -> Id { + Self::TestAccount + } +} +/// The version of the JSON-RPC protocol. 
+#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +pub enum Jsonrpc { + #[serde(rename = "2.0")] + Variant2Period0, +} + +impl Default for Jsonrpc { + fn default() -> Jsonrpc { + Self::Variant2Period0 + } +} +/// The name of the method to invoke. +#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)] +pub enum Method { + #[serde(rename = "getQueueElements")] + GetQueueElementsV2, +} + +impl Default for Method { + fn default() -> Method { + Self::GetQueueElementsV2 + } +} diff --git a/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_request_params.rs b/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_request_params.rs new file mode 100644 index 0000000000..36ffe2cbbf --- /dev/null +++ b/sdk-libs/photon-api/src/models/_get_queue_elements_v2_post_request_params.rs @@ -0,0 +1,38 @@ +/* + * photon-indexer + * + * Solana indexer for general compression + * + * The version of the OpenAPI document: 0.50.0 + * + * Generated by: https://openapi-generator.tech + */ + +use crate::models; + +#[derive(Clone, Default, Debug, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GetQueueElementsV2PostRequestParams { + /// The merkle tree public key + pub tree: String, + + #[serde(skip_serializing_if = "Option::is_none")] + pub output_queue: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub input_queue: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub address_queue: Option, +} + +impl GetQueueElementsV2PostRequestParams { + pub fn new(tree: String) -> GetQueueElementsV2PostRequestParams { + GetQueueElementsV2PostRequestParams { + tree, + output_queue: None, + input_queue: None, + address_queue: None, + } + } +} diff --git a/sdk-libs/photon-api/src/models/address_queue_data_v2.rs b/sdk-libs/photon-api/src/models/address_queue_data_v2.rs index 8c6f6296e5..e4ba4239ca 100644 --- 
a/sdk-libs/photon-api/src/models/address_queue_data_v2.rs +++ b/sdk-libs/photon-api/src/models/address_queue_data_v2.rs @@ -17,21 +17,15 @@ pub struct AddressQueueDataV2 { pub addresses: Vec, pub queue_indices: Vec, /// Deduplicated tree nodes for address tree non-inclusion proofs - #[serde(default)] pub nodes: Vec, pub low_element_indices: Vec, pub low_element_values: Vec, pub low_element_next_indices: Vec, pub low_element_next_values: Vec, - #[serde(default)] - pub low_element_proofs: Vec>, - #[serde(default)] pub leaves_hash_chains: Vec, pub initial_root: String, pub start_index: u64, - #[serde(default)] pub subtrees: Vec, - #[serde(default)] pub root_seq: u64, } @@ -45,7 +39,6 @@ impl AddressQueueDataV2 { low_element_values: Vec, low_element_next_indices: Vec, low_element_next_values: Vec, - low_element_proofs: Vec>, leaves_hash_chains: Vec, initial_root: String, start_index: u64, @@ -60,7 +53,6 @@ impl AddressQueueDataV2 { low_element_values, low_element_next_indices, low_element_next_values, - low_element_proofs, leaves_hash_chains, initial_root, start_index, diff --git a/sdk-libs/program-test/Cargo.toml b/sdk-libs/program-test/Cargo.toml index 49c9a2d8e6..f7f331b0a8 100644 --- a/sdk-libs/program-test/Cargo.toml +++ b/sdk-libs/program-test/Cargo.toml @@ -64,5 +64,5 @@ solana-transaction-status = { workspace = true } bs58 = { workspace = true } light-sdk-types = { workspace = true } tabled = { workspace = true } -chrono = "0.4" -base64 = "0.22" +chrono = { workspace = true } +base64 = { workspace = true } diff --git a/sdk-libs/program-test/src/indexer/test_indexer.rs b/sdk-libs/program-test/src/indexer/test_indexer.rs index e338a0428b..69d6f33fde 100644 --- a/sdk-libs/program-test/src/indexer/test_indexer.rs +++ b/sdk-libs/program-test/src/indexer/test_indexer.rs @@ -677,7 +677,6 @@ impl Indexer for TestIndexer { let mut low_element_next_values = Vec::with_capacity(addresses.len()); let mut low_element_indices = Vec::with_capacity(addresses.len()); let mut 
low_element_next_indices = Vec::with_capacity(addresses.len()); - let mut low_element_proofs = Vec::with_capacity(addresses.len()); // Collect all nodes for deduplication let mut node_map: HashMap = HashMap::new(); @@ -697,7 +696,6 @@ impl Indexer for TestIndexer { .push(bigint_to_be_bytes_array(&old_low_next_value).unwrap()); low_element_indices.push(old_low_element.index as u64); low_element_next_indices.push(old_low_element.next_index as u64); - low_element_proofs.push(proof); } // Convert node map to sorted vectors @@ -705,12 +703,8 @@ impl Indexer for TestIndexer { nodes.sort(); let node_hashes: Vec<[u8; 32]> = nodes.iter().map(|k| node_map[k]).collect(); - let queue_indices: Vec = - (start as u64..(start + addresses.len()) as u64).collect(); - Some(AddressQueueData { addresses, - queue_indices, low_element_values, low_element_next_values, low_element_indices, @@ -722,7 +716,6 @@ impl Indexer for TestIndexer { subtrees: address_tree_bundle.get_subtrees(), start_index: start as u64, root_seq: address_tree_bundle.sequence_number(), - low_element_proofs, }) } else { None diff --git a/sdk-libs/program-test/src/program_test/rpc.rs b/sdk-libs/program-test/src/program_test/rpc.rs index 1578aa0464..3172d7f829 100644 --- a/sdk-libs/program-test/src/program_test/rpc.rs +++ b/sdk-libs/program-test/src/program_test/rpc.rs @@ -16,6 +16,7 @@ use light_event::{ use solana_rpc_client_api::config::RpcSendTransactionConfig; use solana_sdk::{ account::Account, + address_lookup_table::AddressLookupTableAccount, clock::{Clock, Slot}, hash::Hash, instruction::Instruction, @@ -338,6 +339,18 @@ impl Rpc for LightProgramTest { tree_type: TreeType::AddressV2, } } + + async fn create_and_send_versioned_transaction<'a>( + &'a mut self, + _instructions: &'a [Instruction], + _payer: &'a Pubkey, + _signers: &'a [&'a Keypair], + _address_lookup_tables: &'a [AddressLookupTableAccount], + ) -> Result { + unimplemented!( + "create_and_send_versioned_transaction is unimplemented for 
LightProgramTest" + ); + } } impl LightProgramTest { diff --git a/sdk-libs/program-test/src/utils/load_accounts.rs b/sdk-libs/program-test/src/utils/load_accounts.rs index e6ad41837e..a8976232f2 100644 --- a/sdk-libs/program-test/src/utils/load_accounts.rs +++ b/sdk-libs/program-test/src/utils/load_accounts.rs @@ -111,12 +111,9 @@ pub fn load_all_accounts_from_dir() -> Result, RpcError // Decode base64 data let data = if account_data.account.data.1 == "base64" { - use base64::{engine::general_purpose, Engine as _}; - general_purpose::STANDARD - .decode(&account_data.account.data.0) - .map_err(|e| { - RpcError::CustomError(format!("Failed to decode base64 data: {}", e)) - })? + base64::decode(&account_data.account.data.0).map_err(|e| { + RpcError::CustomError(format!("Failed to decode base64 data: {}", e)) + })? } else { return Err(RpcError::CustomError(format!( "Unsupported encoding: {}", @@ -174,9 +171,7 @@ pub fn load_account_from_dir(pubkey: &Pubkey, prefix: Option<&str>) -> Result Self { + let root = Self::compute_root_from_subtrees(&subtrees, next_index); Self { subtrees, next_index, - root: [0u8; 32], + root, _hasher: PhantomData, } } + pub fn new_with_root(subtrees: [[u8; 32]; HEIGHT], next_index: usize, root: [u8; 32]) -> Self { + Self { + subtrees, + next_index, + root, + _hasher: PhantomData, + } + } + + pub fn compute_root_from_subtrees( + subtrees: &[[u8; 32]; HEIGHT], + next_index: usize, + ) -> [u8; 32] { + let mut current_index = next_index; + let mut current_hash = H::zero_bytes()[0]; + + for (subtree, zero_byte) in subtrees.iter().zip(H::zero_bytes().iter()) { + let (left, right) = if current_index.is_multiple_of(2) { + (current_hash, *zero_byte) + } else { + (*subtree, current_hash) + }; + current_hash = H::hashv(&[&left, &right]).unwrap(); + current_index /= 2; + } + + current_hash + } + pub fn new_empty() -> Self { Self { subtrees: H::zero_bytes()[0..HEIGHT].try_into().unwrap(), diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index 
14cbe6b4e8..dc68bb4ddd 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -9,7 +9,7 @@ account-compression = { workspace = true } anyhow = "1.0" ark-bn254 = { workspace = true } ark-ff = { workspace = true } -base64 = "0.22" +base64 = { workspace = true } clap = { version = "4", features = ["derive"] } groth16-solana = { workspace = true } light-concurrent-merkle-tree = { workspace = true } diff --git a/xtask/src/fetch_accounts.rs b/xtask/src/fetch_accounts.rs index 108e5318a2..6d6ef8318c 100644 --- a/xtask/src/fetch_accounts.rs +++ b/xtask/src/fetch_accounts.rs @@ -1,7 +1,6 @@ use std::{fs::File, io::Write, str::FromStr}; use anyhow::Context; -use base64::{engine::general_purpose, Engine as _}; use clap::Parser; use light_program_test::{LightProgramTest, ProgramTestConfig, Rpc}; use serde_json::json; @@ -162,7 +161,7 @@ fn fetch_and_process_lut( let modified_data = decode_and_modify_lut(&account.data, add_pubkeys)?; let filename = format!("modified_lut_{}.json", pubkey); - let data_base64 = general_purpose::STANDARD.encode(&modified_data); + let data_base64 = base64::encode(&modified_data); let json_obj = json!({ "pubkey": pubkey.to_string(), "account": { @@ -245,7 +244,7 @@ fn decode_and_modify_lut(data: &[u8], add_pubkeys: &Option) -> anyhow::R } fn write_account_json(account: &Account, pubkey: &Pubkey, filename: &str) -> anyhow::Result<()> { - let data_base64 = general_purpose::STANDARD.encode(&account.data); + let data_base64 = base64::encode(&account.data); let json_obj = json!({ "pubkey": pubkey.to_string(), "account": {