diff --git a/src/inspectors/precompiles/assertion_adopter.rs b/src/inspectors/precompiles/assertion_adopter.rs index 8db3905..d75daa9 100644 --- a/src/inspectors/precompiles/assertion_adopter.rs +++ b/src/inspectors/precompiles/assertion_adopter.rs @@ -13,9 +13,65 @@ pub fn get_assertion_adopter(context: &PhEvmContext) -> Result(adopter: Address, f: F) -> R + where + F: FnOnce(&PhEvmContext) -> R, + { + let call_tracer = CallTracer::new(); + let logs_and_traces = LogsAndTraces { + tx_logs: &[], + call_traces: &call_tracer, + }; + + let context = PhEvmContext { + logs_and_traces: &logs_and_traces, + adopter, + }; + f(&context) + } + + fn test_get_assertion_adopter_helper(adopter: Address) { + let result = with_adopter_context(adopter, get_assertion_adopter); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + assert!(!encoded.is_empty()); + + // Verify we can decode the result back to the original address + let decoded = Address::abi_decode(&encoded, true); + assert!(decoded.is_ok()); + assert_eq!(decoded.unwrap(), adopter); + } + + #[test] + fn test_get_assertion_adopter_zero_address() { + test_get_assertion_adopter_helper(Address::ZERO); + } + + #[test] + fn test_get_assertion_adopter_random_address() { + test_get_assertion_adopter_helper(random_address()); + } + #[tokio::test] - async fn test_get_assertion_adopter() { + async fn test_get_assertion_adopter_integration() { let result = run_precompile_test("TestGetAdopter").await; assert!(result.is_valid()); let result_and_state = result.result_and_state; diff --git a/src/inspectors/precompiles/calls.rs b/src/inspectors/precompiles/calls.rs index d6f2c0f..0fa213c 100644 --- a/src/inspectors/precompiles/calls.rs +++ b/src/inspectors/precompiles/calls.rs @@ -1,21 +1,23 @@ use crate::{ inspectors::phevm::PhEvmContext, - inspectors::sol_primitives::PhEvm, - primitives::{ - Address, - Bytes, - FixedBytes, + inspectors::sol_primitives::PhEvm::{ + getCallInputsCall, + CallInputs as PhEvmCallInputs, }, + primitives::Bytes, }; use revm::interpreter::CallInputs; -use alloy_sol_types::SolType; +use alloy_sol_types::{ + SolCall, + SolType, +}; #[derive(thiserror::Error, Debug)] pub enum GetCallInputsError { - #[error("Invalid input length, less than 64 bytes: {0}")] - InvalidInputLength(usize), + #[error("Failed to decode getCallInputs call: {0:?}")] + FailedToDecodeGetCallInputsCall(#[from] alloy_sol_types::Error), } /// Returns the call inputs of a transaction. @@ -23,19 +25,11 @@ pub fn get_call_inputs( inputs: &CallInputs, context: &PhEvmContext, ) -> Result { - // Skip function selector (4 bytes) - let input_data = &inputs.input[4..]; + let get_call_inputs = getCallInputsCall::abi_decode(&inputs.input, true)?; - // Input must be at least 64 bytes (2 * 32 byte parameters) - if input_data.len() < 64 { - return Err(GetCallInputsError::InvalidInputLength(input_data.len())); - } - - // Extract address from first parameter (skip 12 bytes padding) - let target = Address::from_slice(&input_data[12..32]); + let target = get_call_inputs.target; + let selector = get_call_inputs.selector; - // Extract selector from second parameter (skip 28 bytes padding) - let selector = FixedBytes::from_slice(&input_data[32..36]); let binding = Vec::new(); let call_inputs = context @@ -45,11 +39,16 @@ pub fn get_call_inputs( .get(&(target, selector)) .unwrap_or(&binding); - let sol_call_inputs: Vec = call_inputs + let sol_call_inputs = call_inputs .iter() .map(|input| { - PhEvm::CallInputs { - input: input.input.clone(), + let original_input_data = input.input.clone(); + let input_data_wo_selector = match original_input_data.len() >= 4 { + true => original_input_data.slice(4..), + false => Bytes::new(), + }; + PhEvmCallInputs { + input: input_data_wo_selector, gas_limit: input.gas_limit, bytecode_address: input.bytecode_address, target_address: input.target_address, @@ -57,20 +56,218 @@ pub fn get_call_inputs( value: input.value.get(), } }) - .collect(); + .collect::>(); let encoded: Bytes = - >::abi_encode(&sol_call_inputs).into(); + >::abi_encode(&sol_call_inputs).into(); Ok(encoded) } #[cfg(test)] mod test { - use crate::test_utils::run_precompile_test; + use super::*; + use crate::{ + inspectors::{ + phevm::{ + LogsAndTraces, + PhEvmContext, + }, + tracer::CallTracer, + }, + test_utils::{ + random_address, + random_bytes, + random_selector, + random_u256, + run_precompile_test, + }, + }; + + fn test_with_inputs_and_tracer( + call_inputs: &CallInputs, + call_tracer: CallTracer, + ) -> Result { + let logs_and_traces = LogsAndTraces { + tx_logs: &[], + call_traces: &call_tracer, + }; + let context = PhEvmContext { + logs_and_traces: &logs_and_traces, + adopter: Address::ZERO, + }; + get_call_inputs(call_inputs, &context) + } + use alloy_primitives::{ + Address, + Bytes, + FixedBytes, + U256, + }; + use revm::interpreter::{ + CallInputs, + CallScheme, + CallValue, + }; + + fn create_get_call_input(target: Address, selector: FixedBytes<4>) -> CallInputs { + let get_call_inputs = getCallInputsCall { target, selector }; + + let input_data = get_call_inputs.abi_encode(); + + CallInputs { + input: Bytes::from(input_data), + gas_limit: 1_000_000, + bytecode_address: Address::ZERO, + target_address: Address::ZERO, + caller: Address::ZERO, + value: CallValue::Transfer(U256::ZERO), + scheme: CallScheme::Call, + is_static: false, + is_eof: false, + return_memory_offset: 0..0, + } + } + + fn create_random_call_input( + target: Address, + selector: FixedBytes<4>, + ) -> CallInputs { + let input_data = [&selector[..], &random_bytes::()[..]].concat(); + CallInputs { + input: Bytes::from(input_data), + gas_limit: 100_000, + bytecode_address: random_address(), + target_address: target, + caller: random_address(), + value: CallValue::Transfer(random_u256()), + scheme: CallScheme::Call, + is_static: false, + is_eof: false, + return_memory_offset: 0..0, + } + } + + fn create_call_tracer_with_inputs(call_inputs: I) -> CallTracer + where + I: IntoIterator, + { + let mut call_tracer = CallTracer::new(); + for input in call_inputs { + call_tracer.record_call(input); + } + call_tracer + } + + #[test] + fn test_get_call_inputs_success() { + let target = random_address(); + let selector = random_selector(); + + // Set up the context + let mock_call_input = create_random_call_input::<32>(target, selector); + + let get_call_inputs = create_get_call_input(target, selector); + let call_tracer = create_call_tracer_with_inputs(vec![mock_call_input.clone()]); + + let result = test_with_inputs_and_tracer(&get_call_inputs, call_tracer); + + let encoded = result.unwrap(); + + // Verify we can decode the result + let decoded = + >::abi_decode(&encoded, false); + + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 1); + + assert_eq!( + decoded_array[0].target_address, + mock_call_input.target_address + ); + + assert_eq!(decoded_array[0].input, mock_call_input.input.slice(4..)); + } + + #[test] + fn test_get_call_inputs_empty_result() { + let target = random_address(); + let selector = random_selector(); + + let get_call_inputs = create_get_call_input(target, selector); + + // Create context with no matching call inputs (different target and selector) + let call_tracer = CallTracer::new(); + + let result = test_with_inputs_and_tracer(&get_call_inputs, call_tracer); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + + // Should return empty array + let decoded = + >::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 0); + } + + #[test] + fn test_get_call_inputs_invalid_input_length() { + let target = random_address(); + let selector = random_selector(); + + let mut get_call_inputs = create_get_call_input(target, selector); + get_call_inputs.input = Bytes::from(random_bytes::<32>()); + + let call_tracer = CallTracer::new(); + + let result = test_with_inputs_and_tracer(&get_call_inputs, call_tracer); + assert!(matches!( + result, + Err(GetCallInputsError::FailedToDecodeGetCallInputsCall(_)) + )); + } + + #[test] + fn test_get_call_inputs_multiple_results() { + let target = random_address(); + let selector = random_selector(); + + let get_call_inputs = create_get_call_input(target, selector); + + // Set up context with multiple call inputs + let mock_call_inputs = vec![ + create_random_call_input::<32>(target, selector), + create_random_call_input::<64>(target, selector), + ]; + + let call_tracer = create_call_tracer_with_inputs(mock_call_inputs.clone()); + + let result = test_with_inputs_and_tracer(&get_call_inputs, call_tracer); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + let decoded = + >::abi_decode(&encoded, true) + .unwrap(); + assert_eq!(decoded.len(), 2); + + // Verify both results are present + assert_eq!( + decoded[0].target_address, + mock_call_inputs[0].target_address + ); + assert_eq!(decoded[0].input, mock_call_inputs[0].input.slice(4..)); + assert_eq!( + decoded[1].target_address, + mock_call_inputs[1].target_address + ); + assert_eq!(decoded[1].input, mock_call_inputs[1].input.slice(4..)); + } #[tokio::test] - async fn test_get_call_inputs() { + async fn test_get_call_inputs_integration() { let result = run_precompile_test("TestGetCallInputs").await; assert!(result.is_valid()); let result_and_state = result.result_and_state; diff --git a/src/inspectors/precompiles/fork.rs b/src/inspectors/precompiles/fork.rs index 336c366..60efc96 100644 --- a/src/inspectors/precompiles/fork.rs +++ b/src/inspectors/precompiles/fork.rs @@ -48,10 +48,242 @@ pub fn fork_post_state( #[cfg(test)] mod test { - use crate::test_utils::run_precompile_test; + use super::*; + use crate::{ + db::{ + overlay::test_utils::MockDb, + MultiForkDb, + }, + primitives::JournaledState, + test_utils::{ + random_address, + random_u256, + run_precompile_test, + }, + }; + use alloy_primitives::{ + Address, + U256, + }; + use revm::{ + primitives::{ + AccountInfo, + BlockEnv, + CfgEnv, + Env, + SpecId, + TxEnv, + KECCAK_EMPTY, + }, + DatabaseRef, + EvmContext, + }; + use std::collections::HashSet; + + fn create_test_context_with_mock_db( + pre_tx_storage: Vec<(Address, U256, U256)>, + post_tx_storage: Vec<(Address, U256, U256)>, + ) -> (MultiForkDb, JournaledState) { + let mut pre_tx_db = MockDb::new(); + let mut post_tx_db = MockDb::new(); + + // Set up pre-tx state + for (address, slot, value) in pre_tx_storage { + pre_tx_db.insert_storage(address, slot, value); + pre_tx_db.insert_account( + address, + AccountInfo { + balance: U256::ZERO, + nonce: 0, + code_hash: KECCAK_EMPTY, + code: None, + }, + ); + } + + // Set up post-tx state + for (address, slot, value) in post_tx_storage { + post_tx_db.insert_storage(address, slot, value); + post_tx_db.insert_account( + address, + AccountInfo { + balance: U256::ZERO, + nonce: 0, + code_hash: KECCAK_EMPTY, + code: None, + }, + ); + } + + let multi_fork_db = MultiForkDb::new(pre_tx_db, post_tx_db); + let journaled_state = JournaledState::new(SpecId::LATEST, HashSet::default()); + + (multi_fork_db, journaled_state) + } + + #[test] + fn test_fork_pre_state_with_mock_db() { + let address = random_address(); + let slot = random_u256(); + let pre_value = U256::from(100); + let post_value = U256::from(200); + + let (mut multi_fork_db, journaled_state) = create_test_context_with_mock_db( + vec![(address, slot, pre_value)], + vec![(address, slot, post_value)], + ); + + let init_journaled_state = journaled_state.clone(); + + // Create EvmContext + let env = Box::new(Env { + cfg: CfgEnv::default(), + block: BlockEnv::default(), + tx: TxEnv::default(), + }); + + let mut context = EvmContext::new_with_env(&mut multi_fork_db, env); + context.inner.journaled_state = journaled_state; + + // Test fork_pre_state function + let result = fork_pre_state(&init_journaled_state, &mut context); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Bytes::default()); + + // Verify that we're now on the pre-tx fork + let storage_value = context.db.storage_ref(address, slot).unwrap(); + assert_eq!(storage_value, pre_value); + } + + #[test] + fn test_fork_post_state_with_mock_db() { + let address = random_address(); + let slot = random_u256(); + let pre_value = U256::from(300); + let post_value = U256::from(400); + + let (mut multi_fork_db, journaled_state) = create_test_context_with_mock_db( + vec![(address, slot, pre_value)], + vec![(address, slot, post_value)], + ); + + let init_journaled_state = journaled_state.clone(); + + // Create EvmContext + let env = Box::new(Env { + cfg: CfgEnv::default(), + block: BlockEnv::default(), + tx: TxEnv::default(), + }); + + let mut context = EvmContext::new_with_env(&mut multi_fork_db, env); + context.inner.journaled_state = journaled_state; + + // Test fork_post_state function + let result = fork_post_state(&init_journaled_state, &mut context); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), Bytes::default()); + + // Verify that we're now on the post-tx fork + let storage_value = context.db.storage_ref(address, slot).unwrap(); + assert_eq!(storage_value, post_value); + } + + #[test] + fn test_fork_switching_between_states() { + let address = random_address(); + let slot = random_u256(); + let pre_value = U256::from(500); + let post_value = U256::from(600); + + let (mut multi_fork_db, journaled_state) = create_test_context_with_mock_db( + vec![(address, slot, pre_value)], + vec![(address, slot, post_value)], + ); + + let init_journaled_state = journaled_state.clone(); + + // Create EvmContext + let env = Box::new(Env { + cfg: CfgEnv::default(), + block: BlockEnv::default(), + tx: TxEnv::default(), + }); + + let mut context = EvmContext::new_with_env(&mut multi_fork_db, env); + context.inner.journaled_state = journaled_state; + + // Start with pre-tx state + let result = fork_pre_state(&init_journaled_state, &mut context); + assert!(result.is_ok()); + let storage_value = context.db.storage_ref(address, slot).unwrap(); + assert_eq!(storage_value, pre_value); + + // Switch to post-tx state + let result = fork_post_state(&init_journaled_state, &mut context); + assert!(result.is_ok()); + let storage_value = context.db.storage_ref(address, slot).unwrap(); + assert_eq!(storage_value, post_value); + + // Switch back to pre-tx state + let result = fork_pre_state(&init_journaled_state, &mut context); + assert!(result.is_ok()); + let storage_value = context.db.storage_ref(address, slot).unwrap(); + assert_eq!(storage_value, pre_value); + } + + #[test] + fn test_fork_with_multiple_accounts_and_storage() { + let address1 = random_address(); + let address2 = random_address(); + let slot1 = random_u256(); + let slot2 = random_u256(); + + let pre_values = vec![ + (address1, slot1, U256::from(10)), + (address2, slot2, U256::from(20)), + ]; + let post_values = vec![ + (address1, slot1, U256::from(30)), + (address2, slot2, U256::from(40)), + ]; + + let (mut multi_fork_db, journaled_state) = + create_test_context_with_mock_db(pre_values.clone(), post_values.clone()); + + let init_journaled_state = journaled_state.clone(); + + // Create EvmContext + let env = Box::new(Env { + cfg: CfgEnv::default(), + block: BlockEnv::default(), + tx: TxEnv::default(), + }); + + let mut context = EvmContext::new_with_env(&mut multi_fork_db, env); + context.inner.journaled_state = journaled_state; + + // Test pre-tx state + let result = fork_pre_state(&init_journaled_state, &mut context); + assert!(result.is_ok()); + + let storage_value1 = context.db.storage_ref(address1, slot1).unwrap(); + let storage_value2 = context.db.storage_ref(address2, slot2).unwrap(); + assert_eq!(storage_value1, U256::from(10)); + assert_eq!(storage_value2, U256::from(20)); + + // Test post-tx state + let result = fork_post_state(&init_journaled_state, &mut context); + assert!(result.is_ok()); + + let storage_value1 = context.db.storage_ref(address1, slot1).unwrap(); + let storage_value2 = context.db.storage_ref(address2, slot2).unwrap(); + assert_eq!(storage_value1, U256::from(30)); + assert_eq!(storage_value2, U256::from(40)); + } #[tokio::test] - async fn test_fork_switching() { + async fn test_fork_integration() { let result = run_precompile_test("TestFork").await; assert!(result.is_valid()); let result_and_state = result.result_and_state; diff --git a/src/inspectors/precompiles/load.rs b/src/inspectors/precompiles/load.rs index 6945c75..a83833d 100644 --- a/src/inspectors/precompiles/load.rs +++ b/src/inspectors/precompiles/load.rs @@ -41,10 +41,159 @@ pub fn load_external_slot( #[cfg(test)] mod test { - use crate::test_utils::run_precompile_test; + use crate::{ + db::overlay::test_utils::MockDb, + inspectors::sol_primitives::PhEvm::loadCall, + primitives::{ + Bytecode, + FixedBytes, + }, + test_utils::{ + random_address, + random_u256, + run_precompile_test, + }, + }; + use alloy_primitives::{ + Address, + Bytes, + U256, + }; + use alloy_sol_types::SolCall; + use revm::{ + interpreter::{ + CallInputs, + CallScheme, + CallValue, + }, + primitives::{ + AccountInfo, + KECCAK_EMPTY, + }, + InnerEvmContext, + }; + + use super::*; + + fn create_call_inputs_for_load(target: Address, slot: U256) -> CallInputs { + let call = loadCall { + target, + slot: slot.into(), + }; + let encoded = call.abi_encode(); + + CallInputs { + input: Bytes::from(encoded), + gas_limit: 1_000_000, + bytecode_address: Address::ZERO, + target_address: target, + caller: Address::ZERO, + value: CallValue::Transfer(U256::ZERO), + scheme: CallScheme::Call, + is_static: false, + is_eof: false, + return_memory_offset: 0..0, + } + } + + fn create_mock_db_with_storage(address: Address, slot: U256, value: U256) -> MockDb { + let mut mock_db = MockDb::new(); + mock_db.insert_storage(address, slot, value); + mock_db.insert_account( + address, + AccountInfo { + balance: U256::ZERO, + nonce: 0, + code_hash: KECCAK_EMPTY, + code: Some(Bytecode::default()), + }, + ); + mock_db + } + + #[test] + fn test_load_call_encoding_roundtrip() { + let target = random_address(); + let slot = random_u256(); + let slot_bytes: alloy_primitives::FixedBytes<32> = slot.into(); + + let call = loadCall { + target, + slot: slot_bytes, + }; + let encoded = call.abi_encode(); + let decoded = loadCall::abi_decode(&encoded, true).unwrap(); + + assert_eq!(decoded.target, target); + assert_eq!(decoded.slot, slot_bytes); + } + + #[test] + fn test_load_call_existing_account_and_storage_value() { + // Test that function returns properly encoded storage value + let target = random_address(); + let slot = random_u256(); + let expected_value = random_u256(); + + // Create valid call inputs + let call_inputs = create_call_inputs_for_load(target, slot); + + // Create context with storage value + let mock_db = create_mock_db_with_storage(target, slot, expected_value); + let mut multi_fork = MultiForkDb::new(MockDb::new(), mock_db); + let context = InnerEvmContext::new(&mut multi_fork); + + let result = load_external_slot(&context, &call_inputs); + let decoded = loadCall::abi_decode_returns(&result.unwrap(), true).unwrap(); + assert_eq!( + decoded.data.0, + FixedBytes::from(expected_value.to_be_bytes()) + ); + } + + #[test] + fn test_load_call_existing_account_no_storage_value() { + // Test with an existing account but no storage value for the requested slot + let target = random_address(); + let slot = random_u256(); + + // Create valid call inputs + let call_inputs = create_call_inputs_for_load(target, slot); + + // Create context with account but no storage for this slot + let mock_db = create_mock_db_with_storage(target, random_u256(), random_u256()); + let mut multi_fork = MultiForkDb::new(MockDb::new(), mock_db); + let context = InnerEvmContext::new(&mut multi_fork); + + let result = load_external_slot(&context, &call_inputs); + let decoded = loadCall::abi_decode_returns(&result.unwrap(), true).unwrap(); + assert_eq!(decoded.data.0, FixedBytes::ZERO); + } + + #[test] + fn test_load_call_non_existing_account() { + // Test with a non-existing account + let target = random_address(); + let slot = random_u256(); + + // Create valid call inputs + let call_inputs = create_call_inputs_for_load(target, slot); + + // Create context with no accounts + let mock_db = create_mock_db_with_storage(Address::ZERO, random_u256(), random_u256()); + + let mut multi_fork = MultiForkDb::new(MockDb::new(), mock_db); + let context = InnerEvmContext::new(&mut multi_fork); + + let result = load_external_slot(&context, &call_inputs); + + // parse the result + let decoded = loadCall::abi_decode_returns(&result.unwrap(), true).unwrap(); + assert_eq!(decoded.data.0, FixedBytes::ZERO); + } #[tokio::test] - async fn test_get_storage() { + async fn test_load_integration() { let result = run_precompile_test("TestLoad").await; assert!(result.is_valid()); let result_and_state = result.result_and_state; diff --git a/src/inspectors/precompiles/logs.rs b/src/inspectors/precompiles/logs.rs index f6ce620..3b9c9f3 100644 --- a/src/inspectors/precompiles/logs.rs +++ b/src/inspectors/precompiles/logs.rs @@ -30,10 +30,272 @@ pub fn get_logs(context: &PhEvmContext) -> Result { #[cfg(test)] mod test { - use crate::test_utils::run_precompile_test; + use super::*; + use crate::{ + inspectors::{ + phevm::{ + LogsAndTraces, + PhEvmContext, + }, + sol_primitives::PhEvm, + tracer::CallTracer, + }, + test_utils::{ + random_address, + random_bytes, + random_bytes32, + run_precompile_test, + }, + }; + use alloy_primitives::{ + Address, + Bytes, + Log, + LogData, + }; + + fn with_logs_context(logs: Vec, f: F) -> R + where + F: FnOnce(&PhEvmContext) -> R, + { + let call_tracer = CallTracer::new(); + let logs_and_traces = LogsAndTraces { + tx_logs: &logs, + call_traces: &call_tracer, + }; + let context = PhEvmContext { + logs_and_traces: &logs_and_traces, + adopter: Address::ZERO, + }; + f(&context) + } + + #[test] + fn test_get_logs_empty() { + let result = with_logs_context(vec![], get_logs); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + assert!(!encoded.is_empty()); + + // Verify we can decode the result + let decoded = >::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 0); + } + + #[test] + fn test_get_logs_single_log() { + let address = random_address(); + let topic = random_bytes32(); + let data = random_bytes::<64>(); + + let log = Log { + address, + data: LogData::new(vec![topic], Bytes::from(data)).unwrap(), + }; + + let result = with_logs_context(vec![log.clone()], get_logs); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + + // Verify we can decode the result + let decoded = >::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 1); + + let decoded_log = &decoded_array[0]; + assert_eq!(decoded_log.emitter, address); + assert_eq!(decoded_log.topics.len(), 1); + assert_eq!(decoded_log.topics[0], topic); + assert_eq!(decoded_log.data, Bytes::from(data)); + } + + #[test] + fn test_get_logs_multiple_logs() { + let address1 = random_address(); + let address2 = random_address(); + let topic1 = random_bytes32(); + let topic2 = random_bytes32(); + let data1 = random_bytes::<32>(); + let data2 = random_bytes::<64>(); + + let logs = vec![ + Log { + address: address1, + data: LogData::new(vec![topic1], Bytes::from(data1)).unwrap(), + }, + Log { + address: address2, + data: LogData::new(vec![topic2], Bytes::from(data2)).unwrap(), + }, + ]; + + let result = with_logs_context(logs, get_logs); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + + // Verify we can decode the result + let decoded = >::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 2); + + // Verify first log + let decoded_log1 = &decoded_array[0]; + assert_eq!(decoded_log1.emitter, address1); + assert_eq!(decoded_log1.topics.len(), 1); + assert_eq!(decoded_log1.topics[0], topic1); + assert_eq!(decoded_log1.data, Bytes::from(data1)); + + // Verify second log + let decoded_log2 = &decoded_array[1]; + assert_eq!(decoded_log2.emitter, address2); + assert_eq!(decoded_log2.topics.len(), 1); + assert_eq!(decoded_log2.topics[0], topic2); + assert_eq!(decoded_log2.data, Bytes::from(data2)); + } + + #[test] + fn test_get_logs_multiple_topics() { + let address = random_address(); + let topic1 = random_bytes32(); + let topic2 = random_bytes32(); + let topic3 = random_bytes32(); + let data = random_bytes::<128>(); + + let log = Log { + address, + data: LogData::new(vec![topic1, topic2, topic3], Bytes::from(data)).unwrap(), + }; + + let result = with_logs_context(vec![log], get_logs); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + + // Verify we can decode the result + let decoded = >::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 1); + + let decoded_log = &decoded_array[0]; + assert_eq!(decoded_log.emitter, address); + assert_eq!(decoded_log.topics.len(), 3); + assert_eq!(decoded_log.topics[0], topic1); + assert_eq!(decoded_log.topics[1], topic2); + assert_eq!(decoded_log.topics[2], topic3); + assert_eq!(decoded_log.data, Bytes::from(data)); + } + + #[test] + fn test_get_logs_no_topics() { + let address = random_address(); + let data = random_bytes::<32>(); + + let log = Log { + address, + data: LogData::new(vec![], Bytes::from(data)).unwrap(), + }; + + let result = with_logs_context(vec![log], get_logs); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + + // Verify we can decode the result + let decoded = >::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 1); + + let decoded_log = &decoded_array[0]; + assert_eq!(decoded_log.emitter, address); + assert_eq!(decoded_log.topics.len(), 0); + assert_eq!(decoded_log.data, Bytes::from(data)); + } + + #[test] + fn test_get_logs_empty_data() { + let address = random_address(); + let topic = random_bytes32(); + + let log = Log { + address, + data: LogData::new(vec![topic], Bytes::new()).unwrap(), + }; + + let result = with_logs_context(vec![log], get_logs); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + + // Verify we can decode the result + let decoded = >::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 1); + + let decoded_log = &decoded_array[0]; + assert_eq!(decoded_log.emitter, address); + assert_eq!(decoded_log.topics.len(), 1); + assert_eq!(decoded_log.topics[0], topic); + assert_eq!(decoded_log.data, Bytes::new()); + } + + #[test] + fn test_get_logs_large_data() { + let address = random_address(); + let topic = random_bytes32(); + let large_data = random_bytes::<1024>(); // 1KB of data + + let log = Log { + address, + data: LogData::new(vec![topic], Bytes::from(large_data)).unwrap(), + }; + + let result = with_logs_context(vec![log], get_logs); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + + // Verify we can decode the result + let decoded = >::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + let decoded_array = decoded.unwrap(); + assert_eq!(decoded_array.len(), 1); + + let decoded_log = &decoded_array[0]; + assert_eq!(decoded_log.emitter, address); + assert_eq!(decoded_log.data, Bytes::from(large_data)); + } + + #[test] + fn test_get_logs_never_fails() { + // The function signature indicates it returns Result + // This means it should never fail, so let's verify that with edge cases + + let test_cases = vec![ + vec![], + vec![Log { + address: Address::ZERO, + data: LogData::new(vec![], Bytes::new()).unwrap(), + }], + ]; + + for logs in test_cases { + let result = with_logs_context(logs, get_logs); + assert!(result.is_ok(), "get_logs should never fail"); + } + } #[tokio::test] - async fn test_get_logs() { + async fn test_get_logs_integration() { let result = run_precompile_test("TestGetLogs").await; assert!(result.is_valid()); let result_and_state = result.result_and_state; diff --git a/src/inspectors/precompiles/state_changes.rs b/src/inspectors/precompiles/state_changes.rs index 9890673..2dd6d55 100644 --- a/src/inspectors/precompiles/state_changes.rs +++ b/src/inspectors/precompiles/state_changes.rs @@ -93,10 +93,401 @@ fn get_differences( #[cfg(test)] mod test { - use crate::test_utils::run_precompile_test; + use super::*; + use crate::{ + inspectors::{ + phevm::{ + LogsAndTraces, + PhEvmContext, + }, + sol_primitives::PhEvm, + tracer::CallTracer, + }, + primitives::{ + Account, + JournaledState, + }, + test_utils::{ + random_address, + random_bytes, + random_u256, + run_precompile_test, + }, + }; + use alloy_primitives::{ + Address, + Bytes, + FixedBytes, + U256, + }; + use alloy_sol_types::{ + SolCall, + SolValue, + }; + use revm::{ + interpreter::{ + CallInputs, + CallScheme, + CallValue, + }, + JournalEntry, + }; + use std::collections::HashSet; + + fn create_call_inputs_for_state_changes(contract_address: Address, slot: U256) -> CallInputs { + let call = PhEvm::getStateChangesCall { + contractAddress: contract_address, + slot: slot.into(), + }; + let encoded = call.abi_encode(); + + CallInputs { + input: Bytes::from(encoded), + gas_limit: 1_000_000, + bytecode_address: Address::ZERO, + target_address: contract_address, + caller: Address::ZERO, + value: CallValue::Transfer(U256::ZERO), + scheme: CallScheme::Call, + is_static: false, + is_eof: false, + return_memory_offset: 0..0, + } + } + + fn with_journaled_state_context(journaled_state: Option, f: F) -> R + where + F: FnOnce(&PhEvmContext) -> R, + { + let mut call_tracer = CallTracer::new(); + call_tracer.journaled_state = journaled_state; + + let logs_and_traces = LogsAndTraces { + tx_logs: &[], + call_traces: &call_tracer, + }; + + let context = PhEvmContext { + logs_and_traces: &logs_and_traces, + adopter: Address::ZERO, + }; + f(&context) + } + + fn create_journaled_state_with_changes( + address: Address, + slot: U256, + old_values: Vec, + current_value: U256, + ) -> JournaledState { + let mut journaled_state = + JournaledState::new(revm::primitives::SpecId::LATEST, HashSet::default()); + + // Add journal entries for storage changes + let mut journal_entries = Vec::new(); + for old_value in old_values { + journal_entries.push(JournalEntry::StorageChanged { + address, + key: slot, + had_value: old_value, + }); + } + journaled_state.journal = vec![journal_entries]; + + // Add the account to state with current storage value + let mut storage = std::collections::HashMap::default(); + storage.insert(slot, revm::primitives::EvmStorageSlot::new(current_value)); + + let account = Account { + info: revm::primitives::AccountInfo::default(), + storage, + status: revm::primitives::AccountStatus::Loaded, + }; + + journaled_state.state.insert(address, account); + + journaled_state + } + + #[test] + fn test_get_state_changes_success() { + let contract_address = random_address(); + let slot = random_u256(); + let old_values = vec![U256::from(100), U256::from(200)]; + let current_value = U256::from(300); + + let call_inputs = create_call_inputs_for_state_changes(contract_address, slot); + let journaled_state = create_journaled_state_with_changes( + contract_address, + slot, + old_values.clone(), + current_value, + ); + let result = with_journaled_state_context(Some(journaled_state), |context| { + get_state_changes(&call_inputs, context) + }); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + let decoded = Vec::::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + + let differences = decoded.unwrap(); + assert_eq!(differences.len(), 3); // 2 old values + 1 current value + assert_eq!(differences[0], U256::from(100)); + assert_eq!(differences[1], U256::from(200)); + assert_eq!(differences[2], current_value); + } + + #[test] + fn test_get_state_changes_no_changes() { + let contract_address = random_address(); + let slot = random_u256(); + + let call_inputs = create_call_inputs_for_state_changes(contract_address, slot); + + // Create empty journaled state with no changes + let journaled_state = + JournaledState::new(revm::primitives::SpecId::LATEST, HashSet::default()); + let result = with_journaled_state_context(Some(journaled_state), |context| { + get_state_changes(&call_inputs, context) + }); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + let decoded = Vec::::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + + let differences = decoded.unwrap(); + assert_eq!(differences.len(), 0); // No changes found + } + + #[test] + fn test_get_state_changes_missing_journaled_state() { + let contract_address = random_address(); + let slot = random_u256(); + + let call_inputs = create_call_inputs_for_state_changes(contract_address, slot); + let result = + with_journaled_state_context(None, |context| get_state_changes(&call_inputs, context)); + assert!(result.is_err()); + + match result.unwrap_err() { + GetStateChangesError::JournaledStateMissing => {} + other => panic!("Expected JournaledStateMissing, got {other:?}"), + } + } + + #[test] + fn test_get_state_changes_invalid_input() { + let invalid_input = Bytes::from(random_bytes::<10>()); // Too short for proper ABI decoding + let call_inputs = CallInputs { + input: invalid_input, + gas_limit: 1_000_000, + bytecode_address: Address::ZERO, + target_address: Address::ZERO, + caller: Address::ZERO, + value: CallValue::Transfer(U256::ZERO), + scheme: CallScheme::Call, + is_static: false, + is_eof: false, + return_memory_offset: 0..0, + }; + + let journaled_state = + JournaledState::new(revm::primitives::SpecId::LATEST, HashSet::default()); + let result = with_journaled_state_context(Some(journaled_state), |context| { + get_state_changes(&call_inputs, context) + }); + assert!(result.is_err()); + + match result.unwrap_err() { + GetStateChangesError::CallDecodeError(_) => {} + other => panic!("Expected CallDecodeError, got {other:?}"), + } + } + + #[test] + fn test_get_state_changes_account_not_found() { + let contract_address = random_address(); + let slot = random_u256(); + + let call_inputs = create_call_inputs_for_state_changes(contract_address, slot); + + // Create journaled state with journal entries but no account in state + let mut journaled_state = + JournaledState::new(revm::primitives::SpecId::LATEST, HashSet::default()); + + // Add journal entry for storage change + let journal_entries = vec![JournalEntry::StorageChanged { + address: contract_address, + key: slot, + had_value: U256::from(100), + }]; + journaled_state.journal = vec![journal_entries]; + // Note: We don't add the account to state, which should cause AccountNotFound error + + let result = with_journaled_state_context(Some(journaled_state), |context| { + get_state_changes(&call_inputs, context) + }); + assert!(result.is_err()); + + match result.unwrap_err() { + GetStateChangesError::AccountNotFound => {} + other => panic!("Expected AccountNotFound, got {other:?}"), + } + } + + #[test] + fn test_get_state_changes_slot_not_found() { + let contract_address = random_address(); + let slot = random_u256(); + + let call_inputs = create_call_inputs_for_state_changes(contract_address, slot); + + // Create journaled state with journal entries and account but no slot in storage + let mut journaled_state = + JournaledState::new(revm::primitives::SpecId::LATEST, HashSet::default()); + + // Add journal entry for storage change + let journal_entries = vec![JournalEntry::StorageChanged { + address: contract_address, + key: slot, + had_value: U256::from(100), + }]; + journaled_state.journal = vec![journal_entries]; + + // Add account but without the requested slot + let account = Account { + info: revm::primitives::AccountInfo::default(), + storage: std::collections::HashMap::default(), // Empty storage - slot not found + status: revm::primitives::AccountStatus::Loaded, + }; + journaled_state.state.insert(contract_address, account); + + let result = with_journaled_state_context(Some(journaled_state), |context| { + get_state_changes(&call_inputs, context) + }); + assert!(result.is_err()); + + match result.unwrap_err() { + GetStateChangesError::SlotNotFound => {} + other => panic!("Expected SlotNotFound, got {other:?}"), + } + } + + #[test] + fn test_get_state_changes_multiple_changes_same_slot() { + let contract_address = random_address(); + let slot = random_u256(); + let old_values = vec![ + U256::from(10), + U256::from(20), + U256::from(30), + U256::from(40), + ]; + let current_value = U256::from(50); + + let call_inputs = create_call_inputs_for_state_changes(contract_address, slot); + let journaled_state = create_journaled_state_with_changes( + contract_address, + slot, + old_values.clone(), + current_value, + ); + let result = with_journaled_state_context(Some(journaled_state), |context| { + get_state_changes(&call_inputs, context) + }); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + let decoded = Vec::::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + + let differences = decoded.unwrap(); + assert_eq!(differences.len(), 5); // 4 old values + 1 current value + + for (i, &old_value) in old_values.iter().enumerate() { + assert_eq!(differences[i], old_value); + } + assert_eq!(differences[4], current_value); + } + + #[test] + fn test_get_state_changes_different_addresses() { + let contract_address = random_address(); + let other_address = random_address(); + let slot = random_u256(); + + let call_inputs = create_call_inputs_for_state_changes(contract_address, slot); + + // Create journaled state with changes to different address (should not match) + let mut journaled_state = + JournaledState::new(revm::primitives::SpecId::LATEST, HashSet::default()); + + // Add journal entry for different address + let journal_entries = vec![JournalEntry::StorageChanged { + address: other_address, // Different address + key: slot, + had_value: U256::from(50), + }]; + journaled_state.journal = vec![journal_entries]; + + let result = with_journaled_state_context(Some(journaled_state), |context| { + get_state_changes(&call_inputs, context) + }); + assert!(result.is_ok()); + + let encoded = result.unwrap(); + let decoded = Vec::::abi_decode(&encoded, false); + assert!(decoded.is_ok()); + + let differences = decoded.unwrap(); + assert_eq!(differences.len(), 0); // No matching changes + } + + #[test] + fn test_get_differences_function() { + let contract_address = random_address(); + let slot = random_u256(); + let old_values = vec![U256::from(1), U256::from(2)]; + let current_value = U256::from(3); + + let journaled_state = create_journaled_state_with_changes( + contract_address, + slot, + old_values.clone(), + current_value, + ); + + let result = get_differences(&journaled_state, contract_address, slot); + assert!(result.is_ok()); + + let differences = result.unwrap(); + assert_eq!(differences.len(), 3); + assert_eq!(differences[0], U256::from(1)); + assert_eq!(differences[1], U256::from(2)); + assert_eq!(differences[2], current_value); + } + + #[test] + fn test_abi_encoding_roundtrip() { + let contract_address = random_address(); + let slot = random_u256(); + + let call = PhEvm::getStateChangesCall { + contractAddress: contract_address, + slot: slot.into(), + }; + let encoded = call.abi_encode(); + let decoded = PhEvm::getStateChangesCall::abi_decode(&encoded, true).unwrap(); + + assert_eq!(decoded.contractAddress, contract_address); + assert_eq!(decoded.slot, FixedBytes::<32>::from(slot)); + } #[tokio::test] - async fn test_get_statechanges() { + async fn test_get_state_changes_integration() { let result = run_precompile_test("TestGetStateChanges").await; assert!(result.is_valid(), "{result:#?}"); let result_and_state = result.result_and_state; diff --git a/src/inspectors/sol_primitives.rs b/src/inspectors/sol_primitives.rs index bef205f..de0222e 100644 --- a/src/inspectors/sol_primitives.rs +++ b/src/inspectors/sol_primitives.rs @@ -3,6 +3,7 @@ use alloy_sol_types::sol; sol! { interface PhEvm { // An Ethereum log + #[derive(Debug)] struct Log { // The topics of the log, including the signature, if any. bytes32[] topics; diff --git a/src/inspectors/tracer.rs b/src/inspectors/tracer.rs index 84bbabe..241aade 100644 --- a/src/inspectors/tracer.rs +++ b/src/inspectors/tracer.rs @@ -37,7 +37,7 @@ impl CallTracer { Self::default() } - pub fn record_call(&mut self, mut inputs: CallInputs) { + pub fn record_call(&mut self, inputs: CallInputs) { // If the input is at least 4 bytes long, use the first 4 bytes as the selector // Otherwise, use 0x00000000 as the default selector // Note: It doesn't mean that the selector is a valid function selector of the target contract @@ -48,10 +48,6 @@ impl CallTracer { FixedBytes::default() // 0x00000000 for ETH transfers/no-input calls }; - if inputs.input.len() >= 4 { - inputs.input = revm::primitives::Bytes::from(inputs.input[4..].to_vec()); - } - self.call_inputs .entry((inputs.target_address, selector)) .or_default() @@ -284,4 +280,100 @@ mod test { expected_triggers_trigger_contract ); } + + #[test] + fn test_triggers_all_types() { + use crate::primitives::{ + JournalEntry, + JournaledState, + SpecId, + }; + use revm::primitives::HashSet as RevmHashSet; + + let mut tracer = CallTracer::new(); + let addr1 = address!("1111111111111111111111111111111111111111"); + let addr2 = address!("2222222222222222222222222222222222222222"); + let addr3 = address!("3333333333333333333333333333333333333333"); + + // Test Call triggers + let selector1 = FixedBytes::<4>::from([0x12, 0x34, 0x56, 0x78]); + let selector2 = FixedBytes::<4>::from([0xAB, 0xCD, 0xEF, 0x00]); + tracer.call_inputs.insert((addr1, selector1), vec![]); + tracer.call_inputs.insert((addr2, selector2), vec![]); + + // Test with journaled state for balance and storage changes + let mut journaled_state = JournaledState::new(SpecId::CANCUN, RevmHashSet::default()); + + // Add balance transfer (should create BalanceChange triggers) + let balance_entries = vec![JournalEntry::BalanceTransfer { + from: addr1, + to: addr2, + balance: U256::from(100), + }]; + + // Add storage changes (should create StorageChange triggers) + let storage_entries = vec![ + JournalEntry::StorageChanged { + address: addr2, + key: U256::from(1), + had_value: U256::from(0), + }, + JournalEntry::StorageChanged { + address: addr3, + key: U256::from(2), + had_value: U256::from(5), + }, + ]; + + journaled_state.journal.push(balance_entries); + journaled_state.journal.push(storage_entries); + tracer.journaled_state = Some(journaled_state); + + let triggers = tracer.triggers(); + + // Verify Call triggers + assert!(triggers[&addr1].contains(&TriggerType::Call { + trigger_selector: selector1 + })); + assert!(triggers[&addr2].contains(&TriggerType::Call { + trigger_selector: selector2 + })); + + // Verify BalanceChange triggers + assert!(triggers[&addr1].contains(&TriggerType::BalanceChange)); + assert!(triggers[&addr2].contains(&TriggerType::BalanceChange)); + + // Verify StorageChange triggers + assert!(triggers[&addr2].contains(&TriggerType::StorageChange { + trigger_slot: U256::from(1).into() + })); + assert!(triggers[&addr3].contains(&TriggerType::StorageChange { + trigger_slot: U256::from(2).into() + })); + + // Verify we have triggers for all expected addresses + assert_eq!(triggers.len(), 3); + assert!(triggers.contains_key(&addr1)); + assert!(triggers.contains_key(&addr2)); + assert!(triggers.contains_key(&addr3)); + } + + #[test] + fn test_triggers_no_journal_state() { + let mut tracer = CallTracer::new(); + let addr = address!("1111111111111111111111111111111111111111"); + let selector = FixedBytes::<4>::from([0x12, 0x34, 0x56, 0x78]); + + // Only call triggers, no journaled state + tracer.call_inputs.insert((addr, selector), vec![]); + + let triggers = tracer.triggers(); + + // Should only have call trigger + assert_eq!(triggers.len(), 1); + assert!(triggers[&addr].contains(&TriggerType::Call { + trigger_selector: selector + })); + assert_eq!(triggers[&addr].len(), 1); + } } diff --git a/src/inspectors/trigger_recorder.rs b/src/inspectors/trigger_recorder.rs index e65fd50..8e86d72 100644 --- a/src/inspectors/trigger_recorder.rs +++ b/src/inspectors/trigger_recorder.rs @@ -326,4 +326,120 @@ mod test { TriggerRecorder { triggers } ); } + + #[test] + fn test_all_trigger_types_manual() { + let mut recorder = TriggerRecorder::default(); + + let selector1 = fixed_bytes!("12345678"); + let selector2 = fixed_bytes!("87654321"); + let selector3 = fixed_bytes!("ABCDEFAB"); + let selector4 = fixed_bytes!("FEDCBAED"); + let selector5 = fixed_bytes!("11111111"); + + // Test all trigger types + recorder.add_trigger(TriggerType::AllCalls, selector1); + recorder.add_trigger( + TriggerType::Call { + trigger_selector: fixed_bytes!("AAAAAAAA"), + }, + selector2, + ); + recorder.add_trigger(TriggerType::AllStorageChanges, selector3); + recorder.add_trigger( + TriggerType::StorageChange { + trigger_slot: fixed_bytes!( + "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB" + ), + }, + selector4, + ); + recorder.add_trigger(TriggerType::BalanceChange, selector5); + + // Verify all triggers were recorded + assert_eq!(recorder.triggers.len(), 5); + + assert!(recorder.triggers[&TriggerType::AllCalls].contains(&selector1)); + assert!(recorder.triggers[&TriggerType::Call { + trigger_selector: fixed_bytes!("AAAAAAAA") + }] + .contains(&selector2)); + assert!(recorder.triggers[&TriggerType::AllStorageChanges].contains(&selector3)); + assert!(recorder.triggers[&TriggerType::StorageChange { + trigger_slot: fixed_bytes!( + "BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB" + ) + }] + .contains(&selector4)); + assert!(recorder.triggers[&TriggerType::BalanceChange].contains(&selector5)); + } + + #[test] + fn test_multiple_selectors_same_trigger() { + let mut recorder = TriggerRecorder::default(); + + let selector1 = fixed_bytes!("11111111"); + let selector2 = fixed_bytes!("22222222"); + let selector3 = fixed_bytes!("33333333"); + + // Add multiple selectors to the same trigger type + recorder.add_trigger(TriggerType::AllCalls, selector1); + recorder.add_trigger(TriggerType::AllCalls, selector2); + recorder.add_trigger(TriggerType::AllCalls, selector3); + + // Should have one trigger type with three selectors + assert_eq!(recorder.triggers.len(), 1); + assert_eq!(recorder.triggers[&TriggerType::AllCalls].len(), 3); + + assert!(recorder.triggers[&TriggerType::AllCalls].contains(&selector1)); + assert!(recorder.triggers[&TriggerType::AllCalls].contains(&selector2)); + assert!(recorder.triggers[&TriggerType::AllCalls].contains(&selector3)); + } + + #[test] + fn test_record_trigger_invalid_selector() { + let mut recorder = TriggerRecorder::default(); + + // Create invalid call inputs with wrong function selector + let call_inputs = CallInputs { + input: Bytes::from([0xFF, 0xFF, 0xFF, 0xFF]), // Invalid selector + gas_limit: 1000000, + target_address: TRIGGER_RECORDER, + bytecode_address: TRIGGER_RECORDER, + caller: Address::random(), + value: revm::interpreter::CallValue::Transfer(U256::ZERO), + scheme: revm::interpreter::CallScheme::Call, + is_static: false, + is_eof: false, + return_memory_offset: 0..0, + }; + + let result = recorder.record_trigger(&call_inputs); + assert!(matches!(result, Err(RecordError::FnSelectorNotFound))); + } + + #[test] + fn test_record_trigger_decode_error() { + let mut recorder = TriggerRecorder::default(); + + // Create call inputs with valid selector but invalid data + let mut invalid_data = ITriggerRecorder::registerCallTrigger_0Call::SELECTOR.to_vec(); + invalid_data.extend_from_slice(&[0xFF; 10]); // Add invalid data + + let call_inputs = CallInputs { + input: invalid_data.into(), + gas_limit: 1000000, + target_address: TRIGGER_RECORDER, + bytecode_address: TRIGGER_RECORDER, + caller: Address::random(), + value: revm::interpreter::CallValue::Transfer(U256::ZERO), + scheme: revm::interpreter::CallScheme::Call, + is_static: false, + is_eof: false, + return_memory_offset: 0..0, + }; + + let result = recorder.record_trigger(&call_inputs); + assert!(matches!(result, Err(RecordError::CallDecodeError(_)))); + } } diff --git a/src/store/assertion_contract_extractor.rs b/src/store/assertion_contract_extractor.rs index 3f92383..5da8059 100644 --- a/src/store/assertion_contract_extractor.rs +++ b/src/store/assertion_contract_extractor.rs @@ -241,3 +241,90 @@ fn test_endless_loop_constructor() { } } } + +#[test] +fn test_extract_all_trigger_types() { + use crate::{ + inspectors::TriggerType, + test_utils::*, + }; + + let config = ExecutorConfig::default(); + + // Test extraction from TriggerOnAny contract which should have AllCalls, AllStorageChanges, and BalanceChange + let (_, trigger_recorder_any) = + extract_assertion_contract(bytecode("TriggerOnAny.sol:TriggerOnAny"), &config).unwrap(); + + // Should have all three "Any" trigger types + assert!(trigger_recorder_any + .triggers + .contains_key(&TriggerType::AllCalls)); + assert!(trigger_recorder_any + .triggers + .contains_key(&TriggerType::AllStorageChanges)); + assert!(trigger_recorder_any + .triggers + .contains_key(&TriggerType::BalanceChange)); + + // Each trigger should have the DEADBEEF selector + let expected_selector = crate::primitives::fixed_bytes!("DEADBEEF"); + assert!(trigger_recorder_any.triggers[&TriggerType::AllCalls].contains(&expected_selector)); + assert!( + trigger_recorder_any.triggers[&TriggerType::AllStorageChanges].contains(&expected_selector) + ); + assert!(trigger_recorder_any.triggers[&TriggerType::BalanceChange].contains(&expected_selector)); + + // Test extraction from TriggerOnSpecific contract which should have specific Call and StorageChange triggers + let (_, trigger_recorder_specific) = + extract_assertion_contract(bytecode("TriggerOnSpecific.sol:TriggerOnSpecific"), &config) + .unwrap(); + + // Should have specific trigger types + let expected_call_trigger = TriggerType::Call { + trigger_selector: crate::primitives::fixed_bytes!("f18c388a"), + }; + let expected_storage_trigger = TriggerType::StorageChange { + trigger_slot: crate::primitives::fixed_bytes!( + "ccc4fa32c72b32fc1388e9b17cbcd9cb5939d52551871739e4c3415f4ee595a0" + ), + }; + + assert!(trigger_recorder_specific + .triggers + .contains_key(&expected_call_trigger)); + assert!(trigger_recorder_specific + .triggers + .contains_key(&expected_storage_trigger)); + + // Both should have the DEADBEEF selector + assert!(trigger_recorder_specific.triggers[&expected_call_trigger].contains(&expected_selector)); + assert!( + trigger_recorder_specific.triggers[&expected_storage_trigger].contains(&expected_selector) + ); +} + +#[test] +fn test_extract_no_triggers_error() { + use crate::test_utils::*; + + let config = ExecutorConfig::default(); + + // Test with a contract that doesn't register any triggers + // This would be a contract that has a triggers() function but doesn't call any register functions + let result = extract_assertion_contract( + bytecode("Target.sol:Target"), // Target contract likely doesn't register triggers + &config, + ); + + match result { + Ok(_) => panic!("Expected NoTriggersRecorded error"), + Err(FnSelectorExtractorError::NoTriggersRecorded) => { + // This is expected + } + Err(other) => { + // The Target contract might not even have a triggers() function, + // so we might get a different error, which is also acceptable for this test + println!("Got different error (acceptable): {other:?}"); + } + } +} diff --git a/src/store/assertion_store.rs b/src/store/assertion_store.rs index 5657173..c0da765 100644 --- a/src/store/assertion_store.rs +++ b/src/store/assertion_store.rs @@ -859,4 +859,226 @@ mod tests { Ok(()) } + + #[test] + fn test_all_trigger_types_comprehensive() -> Result<(), AssertionStoreError> { + let aa = Address::random(); + + // Create unique selectors for each trigger type + let selector_specific_call = FixedBytes::<4>::random(); + let selector_all_calls = FixedBytes::<4>::random(); + let selector_specific_storage = FixedBytes::<4>::random(); + let selector_all_storage = FixedBytes::<4>::random(); + let selector_balance = FixedBytes::<4>::random(); + + let trigger_selector = FixedBytes::<4>::from([0x12, 0x34, 0x56, 0x78]); + let trigger_slot = U256::from(42); + + // Create recorded triggers for ALL trigger types + let recorded_triggers = vec![ + ( + TriggerType::Call { trigger_selector }, + vec![selector_specific_call] + .into_iter() + .collect::>(), + ), + ( + TriggerType::AllCalls, + vec![selector_all_calls].into_iter().collect::>(), + ), + ( + TriggerType::StorageChange { + trigger_slot: trigger_slot.into(), + }, + vec![selector_specific_storage] + .into_iter() + .collect::>(), + ), + ( + TriggerType::AllStorageChanges, + vec![selector_all_storage] + .into_iter() + .collect::>(), + ), + ( + TriggerType::BalanceChange, + vec![selector_balance].into_iter().collect::>(), + ), + ]; + + // Create journal entries that trigger specific call, storage change, and balance change + let journal_entries = vec![ + JournalEntry::StorageChanged { + address: aa, + key: trigger_slot, + had_value: U256::from(0), + }, + JournalEntry::StorageChanged { + address: aa, + key: U256::from(99), // Different slot to trigger AllStorageChanges + had_value: U256::from(1), + }, + JournalEntry::BalanceTransfer { + from: aa, + to: Address::random(), + balance: U256::from(100), + }, + ]; + + let store = AssertionStore::new_ephemeral()?; + let mut trigger_recorder = TriggerRecorder::default(); + + recorded_triggers.iter().for_each(|(trigger, selectors)| { + trigger_recorder + .triggers + .insert(trigger.clone(), selectors.clone()); + }); + + let mut assertion = create_test_assertion(100, None); + assertion.trigger_recorder = trigger_recorder; + store.insert(aa, assertion)?; + + let mut tracer = CallTracer::default(); + // Add call with the specific trigger selector + tracer.call_inputs.insert((aa, trigger_selector), vec![]); + + tracer.journaled_state = Some(JournaledState::new( + SpecId::LONDON, + RevmHashSet::
::default(), + )); + + tracer + .journaled_state + .as_mut() + .unwrap() + .journal + .push(journal_entries); + + let assertions = store.read(&tracer, U256::from(100))?; + assert_eq!(assertions.len(), 1); + + // All selectors should be included since we triggered: + // - Call (specific trigger_selector) + // - AllCalls (because we had a call) + // - StorageChange (specific trigger_slot) + // - AllStorageChanges (because we had storage changes) + // - BalanceChange (because we had a balance transfer) + let mut expected_selectors = vec![ + selector_specific_call, + selector_all_calls, + selector_specific_storage, + selector_all_storage, + selector_balance, + ]; + expected_selectors.sort(); + + let mut matched_selectors = assertions[0].selectors.clone(); + matched_selectors.sort(); + assert_eq!(matched_selectors, expected_selectors); + + Ok(()) + } + + #[test] + fn test_no_matching_triggers() -> Result<(), AssertionStoreError> { + let aa = Address::random(); + + // Create triggers that won't be matched + let recorded_triggers = vec![ + ( + TriggerType::Call { + trigger_selector: FixedBytes::<4>::from([0xFF, 0xFF, 0xFF, 0xFF]), + }, + vec![FixedBytes::<4>::random()] + .into_iter() + .collect::>(), + ), + ( + TriggerType::StorageChange { + trigger_slot: U256::from(999).into(), + }, + vec![FixedBytes::<4>::random()] + .into_iter() + .collect::>(), + ), + ]; + + // Create journal entries that DON'T match the triggers + let journal_entries = vec![ + JournalEntry::StorageChanged { + address: aa, + key: U256::from(123), // Different slot + had_value: U256::from(0), + }, + JournalEntry::BalanceTransfer { + from: aa, + to: Address::random(), + balance: U256::from(100), + }, + ]; + + let assertions = setup_and_match(recorded_triggers, journal_entries, aa)?; + assert_eq!(assertions.len(), 1); + + // No selectors should match since triggers don't align + assert_eq!(assertions[0].selectors.len(), 0); + + Ok(()) + } + + #[test] + fn test_partial_trigger_matching() -> Result<(), AssertionStoreError> { + let aa = Address::random(); + + let selector1 = FixedBytes::<4>::random(); + let selector2 = FixedBytes::<4>::random(); + let selector3 = FixedBytes::<4>::random(); + + // Setup triggers where only some will match + let recorded_triggers = vec![ + ( + TriggerType::Call { + trigger_selector: FixedBytes::<4>::from([0x11, 0x22, 0x33, 0x44]), + }, + vec![selector1].into_iter().collect::>(), + ), + ( + TriggerType::StorageChange { + trigger_slot: U256::from(5).into(), + }, + vec![selector2].into_iter().collect::>(), + ), + ( + TriggerType::BalanceChange, + vec![selector3].into_iter().collect::>(), + ), + ]; + + // Only trigger storage change and balance change, not the specific call + let journal_entries = vec![ + JournalEntry::StorageChanged { + address: aa, + key: U256::from(5), // Matches the trigger + had_value: U256::from(0), + }, + JournalEntry::BalanceTransfer { + from: aa, + to: Address::random(), + balance: U256::from(100), + }, + ]; + + let assertions = setup_and_match(recorded_triggers, journal_entries, aa)?; + assert_eq!(assertions.len(), 1); + + // Should only have selectors for storage change and balance change + let mut expected_selectors = vec![selector2, selector3]; + expected_selectors.sort(); + + let mut matched_selectors = assertions[0].selectors.clone(); + matched_selectors.sort(); + assert_eq!(matched_selectors, expected_selectors); + + Ok(()) + } } diff --git a/src/test_utils.rs b/src/test_utils.rs index 7102ad4..7911feb 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -105,6 +105,22 @@ pub fn random_bytes() -> FixedBytes { FixedBytes::new(value) } +pub fn random_address() -> Address { + random_bytes::<20>().into() +} + +pub fn random_u256() -> U256 { + random_bytes::<32>().into() +} + +pub fn random_selector() -> FixedBytes<4> { + random_bytes::<4>() +} + +pub fn random_bytes32() -> FixedBytes<32> { + random_bytes::<32>() +} + fn read_artifact(input: &str) -> serde_json::Value { let mut parts = input.split(':'); let file_name = parts.next().expect("Failed to read filename");