diff --git a/core-primitives/settings/src/lib.rs b/core-primitives/settings/src/lib.rs index 536949f983..b945a1a955 100644 --- a/core-primitives/settings/src/lib.rs +++ b/core-primitives/settings/src/lib.rs @@ -48,7 +48,6 @@ pub mod files { // used by worker and enclave pub const SHARDS_PATH: &str = "shards"; - pub const ENCRYPTED_STATE_FILE: &str = "state.bin"; pub const LAST_SLOT_BIN: &str = "last_slot.bin"; #[cfg(feature = "production")] diff --git a/core-primitives/stf-state-handler/src/error.rs b/core-primitives/stf-state-handler/src/error.rs index e2db62301c..e283c657a8 100644 --- a/core-primitives/stf-state-handler/src/error.rs +++ b/core-primitives/stf-state-handler/src/error.rs @@ -51,6 +51,8 @@ pub enum Error { OsStringConversion, #[error("SGX crypto error: {0}")] CryptoError(itp_sgx_crypto::Error), + #[error("IO error: {0}")] + IO(std::io::Error), #[error("SGX error, status: {0}")] SgxError(sgx_status_t), #[error(transparent)] @@ -59,7 +61,7 @@ pub enum Error { impl From for Error { fn from(e: std::io::Error) -> Self { - Self::Other(e.into()) + Self::IO(e) } } diff --git a/core-primitives/stf-state-handler/src/file_io.rs b/core-primitives/stf-state-handler/src/file_io.rs index e8d522a9f5..c0de994cb5 100644 --- a/core-primitives/stf-state-handler/src/file_io.rs +++ b/core-primitives/stf-state-handler/src/file_io.rs @@ -19,23 +19,97 @@ use crate::sgx_reexport_prelude::*; #[cfg(any(test, feature = "std"))] -use rust_base58::base58::ToBase58; +use rust_base58::base58::{FromBase58, ToBase58}; #[cfg(feature = "sgx")] -use base58::ToBase58; - -#[cfg(any(test, feature = "sgx"))] -use itp_settings::files::ENCRYPTED_STATE_FILE; +use base58::{FromBase58, ToBase58}; #[cfg(any(test, feature = "sgx"))] use std::string::String; use crate::{error::Result, state_snapshot_primitives::StateId}; -use codec::Encode; +use codec::{Decode, Encode}; +// Todo: Can be migrated to here in the course of #1292. 
use itp_settings::files::SHARDS_PATH; use itp_types::ShardIdentifier; use log::error; -use std::{format, path::PathBuf, vec::Vec}; +use std::{ + format, + path::{Path, PathBuf}, + vec::Vec, +}; + +/// File name of the encrypted state file. +/// +/// It is also the suffix of all past snapshots. +pub const ENCRYPTED_STATE_FILE: &str = "state.bin"; + +/// Helps with file system operations of all files relevant for the State. +#[derive(Debug, Clone, Eq, PartialEq)] +pub struct StateDir { + base_path: PathBuf, +} + +impl StateDir { + pub fn new(base_path: PathBuf) -> Self { + Self { base_path } + } + + pub fn shards_directory(&self) -> PathBuf { + self.base_path.join(SHARDS_PATH) + } + + pub fn shard_path(&self, shard: &ShardIdentifier) -> PathBuf { + self.shards_directory().join(shard.encode().to_base58()) + } + + pub fn list_shards(&self) -> Result> { + Ok(list_shards(&self.shards_directory()) + .map(|iter| iter.collect()) + // return an empty vec in case the directory does not exist. + .unwrap_or_default()) + } + + pub fn list_state_ids_for_shard( + &self, + shard_identifier: &ShardIdentifier, + ) -> Result> { + let shard_path = self.shard_path(shard_identifier); + Ok(state_ids_for_shard(shard_path.as_path())?.collect()) + } + + pub fn purge_shard_dir(&self, shard: &ShardIdentifier) { + let shard_dir_path = self.shard_path(shard); + if let Err(e) = std::fs::remove_dir_all(&shard_dir_path) { + error!("Failed to remove shard directory {:?}: {:?}", shard_dir_path, e); + } + } + + pub fn shard_exists(&self, shard: &ShardIdentifier) -> bool { + let shard_path = self.shard_path(shard); + shard_path.exists() && shard_contains_valid_state_id(&shard_path) + } + + pub fn create_shard(&self, shard: &ShardIdentifier) -> Result<()> { + Ok(std::fs::create_dir_all(self.shard_path(shard))?) 
+ } + + pub fn state_file_path(&self, shard: &ShardIdentifier, state_id: StateId) -> PathBuf { + self.shard_path(shard).join(to_file_name(state_id)) + } + + pub fn file_for_state_exists(&self, shard: &ShardIdentifier, state_id: StateId) -> bool { + self.state_file_path(shard, state_id).exists() + } + + #[cfg(feature = "test")] + pub fn given_initialized_shard(&self, shard: &ShardIdentifier) { + if self.shard_exists(shard) { + self.purge_shard_dir(shard); + } + self.create_shard(&shard).unwrap() + } +} /// Trait to abstract file I/O for state. pub trait StateFileIo { @@ -91,10 +165,8 @@ pub trait StateFileIo { #[cfg(feature = "sgx")] pub mod sgx { - use super::*; use crate::error::Error; - use base58::FromBase58; use codec::Decode; use core::fmt::Debug; use itp_hashing::Hash; @@ -108,6 +180,7 @@ pub mod sgx { /// SGX state file I/O. pub struct SgxStateFileIo { state_key_repository: Arc, + state_dir: StateDir, _phantom: PhantomData, } @@ -117,8 +190,8 @@ pub mod sgx { ::KeyType: StateCrypto, State: SgxExternalitiesTrait, { - pub fn new(state_key_repository: Arc) -> Self { - SgxStateFileIo { state_key_repository, _phantom: PhantomData } + pub fn new(state_key_repository: Arc, state_dir: StateDir) -> Self { + SgxStateFileIo { state_key_repository, state_dir, _phantom: PhantomData } } fn read(&self, path: &Path) -> Result> { @@ -163,11 +236,11 @@ pub mod sgx { shard_identifier: &ShardIdentifier, state_id: StateId, ) -> Result { - if !file_for_state_exists(shard_identifier, state_id) { + if !self.state_dir.file_for_state_exists(shard_identifier, state_id) { return Err(Error::InvalidStateId(state_id)) } - let state_path = state_file_path(shard_identifier, state_id); + let state_path = self.state_dir.state_file_path(shard_identifier, state_id); trace!("loading state from: {:?}", state_path); let state_encoded = self.read(&state_path)?; @@ -203,7 +276,7 @@ pub mod sgx { state_id: StateId, state: &Self::StateType, ) -> Result { - init_shard(&shard_identifier)?; + 
self.state_dir.create_shard(&shard_identifier)?; self.write(shard_identifier, state_id, state) } @@ -215,7 +288,7 @@ pub mod sgx { state_id: StateId, state: &Self::StateType, ) -> Result { - let state_path = state_file_path(shard_identifier, state_id); + let state_path = self.state_dir.state_file_path(shard_identifier, state_id); trace!("writing state to: {:?}", state_path); // Only save the state, the state diff is pruned. @@ -229,114 +302,79 @@ pub mod sgx { } fn remove(&self, shard_identifier: &ShardIdentifier, state_id: StateId) -> Result<()> { - fs::remove_file(state_file_path(shard_identifier, state_id)) - .map_err(|e| Error::Other(e.into())) + Ok(fs::remove_file(self.state_dir.state_file_path(shard_identifier, state_id))?) } fn shard_exists(&self, shard_identifier: &ShardIdentifier) -> bool { - shard_exists(shard_identifier) + self.state_dir.shard_exists(shard_identifier) } fn list_shards(&self) -> Result> { - list_shards() + self.state_dir.list_shards() } - fn list_state_ids_for_shard( - &self, - shard_identifier: &ShardIdentifier, - ) -> Result> { - let shard_path = shard_path(shard_identifier); - let directory_items = list_items_in_directory(&shard_path); - - Ok(directory_items - .iter() - .flat_map(|item| { - let maybe_state_id = extract_state_id_from_file_name(item.as_str()); - if maybe_state_id.is_none() { - warn!("Found item ({}) that does not match state snapshot naming pattern, ignoring it", item) - } - maybe_state_id - }) - .collect()) + fn list_state_ids_for_shard(&self, shard: &ShardIdentifier) -> Result> { + self.state_dir.list_state_ids_for_shard(shard) } } +} - fn state_file_path(shard: &ShardIdentifier, state_id: StateId) -> PathBuf { - let mut shard_file_path = shard_path(shard); - shard_file_path.push(to_file_name(state_id)); - shard_file_path - } - - fn file_for_state_exists(shard: &ShardIdentifier, state_id: StateId) -> bool { - state_file_path(shard, state_id).exists() - } - - /// Returns true if a shard directory for a given identifier 
exists AND contains at least one state file. - pub(crate) fn shard_exists(shard: &ShardIdentifier) -> bool { - let shard_path = shard_path(shard); - if !shard_path.exists() { - return false +/// Lists all files with a valid state snapshot naming pattern. +pub(crate) fn state_ids_for_shard(shard_path: &Path) -> Result> { + Ok(items_in_directory(shard_path)?.filter_map(|item| { + match extract_state_id_from_file_name(&item) { + Some(state_id) => Some(state_id), + None => { + log::warn!( + "Found item ({}) that does not match state snapshot naming pattern, ignoring it", + item + ); + None + }, } + })) +} - shard_path - .read_dir() - // When the iterator over all files in the directory returns none, the directory is empty. - .map(|mut d| d.next().is_some()) - .unwrap_or(false) - } - - pub(crate) fn init_shard(shard: &ShardIdentifier) -> Result<()> { - let path = shard_path(shard); - fs::create_dir_all(path).map_err(|e| Error::Other(e.into())) - } - - /// List any valid shards that are found in the shard path. - /// Ignore any items (files, directories) that are not valid shard identifiers. - pub(crate) fn list_shards() -> Result> { - let directory_items = list_items_in_directory(&PathBuf::from(format!("./{}", SHARDS_PATH))); - Ok(directory_items - .iter() - .flat_map(|item| { - item.from_base58() - .ok() - .map(|encoded_shard_id| { - ShardIdentifier::decode(&mut encoded_shard_id.as_slice()).ok() - }) - .flatten() - }) - .collect()) - } - - fn list_items_in_directory(directory: &Path) -> Vec { - let items = match directory.read_dir() { - Ok(rd) => rd, - Err(_) => return Vec::new(), - }; +/// Returns an iterator over all valid shards in a directory. +/// +/// Ignore any items (files, directories) that are not valid shard identifiers. +pub(crate) fn list_shards(path: &Path) -> Result> { + Ok(items_in_directory(path)?.filter_map(|base58| match shard_from_base58(&base58) { + Ok(shard) => Some(shard), + Err(e) => { + error!("Found invalid shard ({}). 
Error: {:?}", base58, e); + None + }, + })) +} - items - .flat_map(|fr| fr.map(|de| de.file_name().into_string().ok()).ok().flatten()) - .collect() - } } +fn shard_from_base58(base58: &str) -> Result { + let vec = base58.from_base58()?; + Ok(Decode::decode(&mut vec.as_slice())?) } -/// Remove a shard directory with all of its content. -pub fn purge_shard_dir(shard: &ShardIdentifier) { - let shard_dir_path = shard_path(shard); - if let Err(e) = std::fs::remove_dir_all(&shard_dir_path) { - error!("Failed to remove shard directory {:?}: {:?}", shard_dir_path, e); - } +/// Returns an iterator over all filenames in a directory. +fn items_in_directory(directory: &Path) -> Result> { + Ok(directory + .read_dir()? + .filter_map(|fr| fr.ok().and_then(|de| de.file_name().into_string().ok()))) } -pub(crate) fn shard_path(shard: &ShardIdentifier) -> PathBuf { - PathBuf::from(format!("./{}/{}", SHARDS_PATH, shard.encode().to_base58())) +fn shard_contains_valid_state_id(path: &Path) -> bool { + // If at least one item can be decoded into a state id, the shard is not empty. 
+ match state_ids_for_shard(path) { + Ok(mut iter) => iter.next().is_some(), + Err(e) => { + error!("Error in reading shard dir: {:?}", e); + false + }, + } } -#[cfg(any(test, feature = "sgx"))] fn to_file_name(state_id: StateId) -> String { format!("{}_{}", state_id, ENCRYPTED_STATE_FILE) } -#[cfg(any(test, feature = "sgx"))] fn extract_state_id_from_file_name(file_name: &str) -> Option { let state_id_str = file_name.strip_suffix(format!("_{}", ENCRYPTED_STATE_FILE).as_str())?; state_id_str.parse::().ok() diff --git a/core-primitives/stf-state-handler/src/in_memory_state_file_io.rs b/core-primitives/stf-state-handler/src/in_memory_state_file_io.rs index c979352db4..702ccac0ab 100644 --- a/core-primitives/stf-state-handler/src/in_memory_state_file_io.rs +++ b/core-primitives/stf-state-handler/src/in_memory_state_file_io.rs @@ -224,11 +224,14 @@ fn sgx_externalities_wrapper() -> ExternalStateGenerator Result>> { - let shards = list_shards()?; + let shards: Vec = + list_shards(path).map(|iter| iter.collect()).unwrap_or_default(); Ok(create_in_memory_externalities_state_io(&shards)) } } diff --git a/core-primitives/stf-state-handler/src/test/sgx_tests.rs b/core-primitives/stf-state-handler/src/test/sgx_tests.rs index da659e06f9..42c33f512b 100644 --- a/core-primitives/stf-state-handler/src/test/sgx_tests.rs +++ b/core-primitives/stf-state-handler/src/test/sgx_tests.rs @@ -16,12 +16,7 @@ */ use crate::{ - error::{Error, Result}, - file_io::{ - purge_shard_dir, - sgx::{init_shard, shard_exists, SgxStateFileIo}, - shard_path, StateFileIo, - }, + file_io::{sgx::SgxStateFileIo, StateDir, StateFileIo}, handle_state::HandleState, in_memory_state_file_io::sgx::create_in_memory_state_io_from_shards_directories, query_shard_state::QueryShardState, @@ -56,33 +51,6 @@ type TestStateRepositoryLoader = type TestStateObserver = StateObserver; type TestStateHandler = StateHandler; -/// Directory handle to automatically initialize a directory -/// and upon dropping the reference, 
removing it again. -struct ShardDirectoryHandle { - shard: ShardIdentifier, -} - -impl ShardDirectoryHandle { - pub fn new(shard: ShardIdentifier) -> Result { - given_initialized_shard(&shard)?; - Ok(ShardDirectoryHandle { shard }) - } -} - -impl Drop for ShardDirectoryHandle { - fn drop(&mut self) { - purge_shard_dir(&self.shard) - } -} - -/// Gets a temporary key repository. -/// -/// We pass and ID such that it doesn't clash with other temp repositories. -fn temp_state_key_repository(id: &str) -> StateKeyRepository { - let temp_dir = TempDir::with_prefix(id).unwrap(); - get_aes_repository(temp_dir.path().to_path_buf()).unwrap() -} - // Fixme: Move this test to sgx-runtime: // // https://github.com/integritee-network/sgx-runtime/issues/23 @@ -101,8 +69,9 @@ pub fn test_sgx_state_decode_encode_works() { pub fn test_encrypt_decrypt_state_type_works() { // given let state = given_hello_world_state(); - - let state_key = temp_state_key_repository("test_encrypt_decrypt_state_type_works") + let temp_dir = TempDir::with_prefix("test_encrypt_decrypt_state_type_works").unwrap(); + let state_key = get_aes_repository(temp_dir.path().to_path_buf()) + .unwrap() .retrieve_key() .unwrap(); @@ -120,9 +89,10 @@ pub fn test_encrypt_decrypt_state_type_works() { pub fn test_write_and_load_state_works() { // given let shard: ShardIdentifier = [94u8; 32].into(); - let state_key_access = Arc::new(temp_state_key_repository("test_write_and_load_state_works")); - let (state_handler, shard_dir_handle) = - initialize_state_handler_with_directory_handle(&shard, state_key_access); + let (_temp_dir, state_key_access, state_dir) = + test_setup("test_write_and_load_state_works", &shard); + + let state_handler = initialize_state_handler(state_key_access, state_dir); let state = given_hello_world_state(); @@ -134,18 +104,15 @@ pub fn test_write_and_load_state_works() { // then assert_eq!(state.state, result_state.state); - - // clean up - std::mem::drop(shard_dir_handle); } pub fn 
test_ensure_subsequent_state_loads_have_same_hash() { // given let shard: ShardIdentifier = [49u8; 32].into(); - let state_key_access = - Arc::new(temp_state_key_repository("test_ensure_subsequent_state_loads_have_same_hash")); - let (state_handler, shard_dir_handle) = - initialize_state_handler_with_directory_handle(&shard, state_key_access); + let (_temp_dir, state_key_access, state_dir) = + test_setup("test_ensure_subsequent_state_loads_have_same_hash", &shard); + + let state_handler = initialize_state_handler(state_key_access, state_dir); let (lock, initial_state) = state_handler.load_for_mutation(&shard).unwrap(); state_handler.write_after_mutation(initial_state.clone(), lock, &shard).unwrap(); @@ -153,9 +120,6 @@ pub fn test_ensure_subsequent_state_loads_have_same_hash() { let (_, loaded_state_hash) = state_handler.load_cloned(&shard).unwrap(); assert_eq!(initial_state.hash(), loaded_state_hash); - - // clean up - std::mem::drop(shard_dir_handle); } pub fn test_write_access_locks_read_until_finished() { @@ -164,10 +128,10 @@ pub fn test_write_access_locks_read_until_finished() { // given let shard: ShardIdentifier = [47u8; 32].into(); - let state_key_access = - Arc::new(temp_state_key_repository("test_write_access_locks_read_until_finished")); - let (state_handler, shard_dir_handle) = - initialize_state_handler_with_directory_handle(&shard, state_key_access); + let (_temp_dir, state_key_access, state_dir) = + test_setup("test_write_access_locks_read_until_finished", &shard); + + let state_handler = initialize_state_handler(state_key_access, state_dir); let new_state_key = "my_new_state".encode(); let (lock, mut state_to_mutate) = state_handler.load_for_mutation(&shard).unwrap(); @@ -190,67 +154,60 @@ pub fn test_write_access_locks_read_until_finished() { let _hash = state_handler.write_after_mutation(state_to_mutate, lock, &shard).unwrap(); join_handle.join().unwrap(); - - // clean up - std::mem::drop(shard_dir_handle); } pub fn 
test_state_handler_file_backend_is_initialized() { let shard: ShardIdentifier = [11u8; 32].into(); - let state_key_access = - Arc::new(temp_state_key_repository("test_state_handler_file_backend_is_initialized")); - let (state_handler, shard_dir_handle) = - initialize_state_handler_with_directory_handle(&shard, state_key_access); + let (_temp_dir, state_key_access, state_dir) = + test_setup("test_state_handler_file_backend_is_initialized", &shard); + + let state_handler = initialize_state_handler(state_key_access, state_dir.clone()); assert!(state_handler.shard_exists(&shard).unwrap()); assert!(1 <= state_handler.list_shards().unwrap().len()); // only greater equal, because there might be other (non-test) shards present - assert_eq!(1, number_of_files_in_shard_dir(&shard).unwrap()); // creates a first initialized file + assert_eq!(1, state_dir.list_state_ids_for_shard(&shard).unwrap().len()); // creates a first initialized file let _state = state_handler.load_cloned(&shard).unwrap(); - assert_eq!(1, number_of_files_in_shard_dir(&shard).unwrap()); - - // clean up - std::mem::drop(shard_dir_handle); + assert_eq!(1, state_dir.list_state_ids_for_shard(&shard).unwrap().len()); } pub fn test_multiple_state_updates_create_snapshots_up_to_cache_size() { let shard: ShardIdentifier = [17u8; 32].into(); - let state_key_access = Arc::new(temp_state_key_repository( - "test_multiple_state_updates_create_snapshots_up_to_cache_size", - )); - let (state_handler, _shard_dir_handle) = - initialize_state_handler_with_directory_handle(&shard, state_key_access); + let (_temp_dir, state_key_access, state_dir) = + test_setup("test_state_handler_file_backend_is_initialized", &shard); - assert_eq!(1, number_of_files_in_shard_dir(&shard).unwrap()); + let state_handler = initialize_state_handler(state_key_access, state_dir.clone()); + + assert_eq!(1, state_dir.list_state_ids_for_shard(&shard).unwrap().len()); let hash_1 = update_state( state_handler.as_ref(), &shard, ("my_key_1".encode(), 
"mega_secret_value".encode()), ); - assert_eq!(2, number_of_files_in_shard_dir(&shard).unwrap()); + assert_eq!(2, state_dir.list_state_ids_for_shard(&shard).unwrap().len()); let hash_2 = update_state( state_handler.as_ref(), &shard, ("my_key_2".encode(), "mega_secret_value222".encode()), ); - assert_eq!(3, number_of_files_in_shard_dir(&shard).unwrap()); + assert_eq!(3, state_dir.list_state_ids_for_shard(&shard).unwrap().len()); let hash_3 = update_state( state_handler.as_ref(), &shard, ("my_key_3".encode(), "mega_secret_value3".encode()), ); - assert_eq!(3, number_of_files_in_shard_dir(&shard).unwrap()); + assert_eq!(3, state_dir.list_state_ids_for_shard(&shard).unwrap().len()); let hash_4 = update_state( state_handler.as_ref(), &shard, ("my_key_3".encode(), "mega_secret_valuenot3".encode()), ); - assert_eq!(3, number_of_files_in_shard_dir(&shard).unwrap()); + assert_eq!(3, state_dir.list_state_ids_for_shard(&shard).unwrap().len()); assert_ne!(hash_1, hash_2); assert_ne!(hash_1, hash_3); @@ -259,15 +216,18 @@ pub fn test_multiple_state_updates_create_snapshots_up_to_cache_size() { assert_ne!(hash_2, hash_4); assert_ne!(hash_3, hash_4); - assert_eq!(STATE_SNAPSHOTS_CACHE_SIZE, number_of_files_in_shard_dir(&shard).unwrap()); + assert_eq!( + STATE_SNAPSHOTS_CACHE_SIZE, + state_dir.list_state_ids_for_shard(&shard).unwrap().len() + ); } pub fn test_file_io_get_state_hash_works() { let shard: ShardIdentifier = [21u8; 32].into(); - let _shard_dir_handle = ShardDirectoryHandle::new(shard).unwrap(); - let state_key_access = Arc::new(temp_state_key_repository("test_file_io_get_state_hash_works")); + let (_temp_dir, state_key_access, state_dir) = + test_setup("test_file_io_get_state_hash_works", &shard); - let file_io = TestStateFileIo::new(state_key_access); + let file_io = TestStateFileIo::new(state_key_access, state_dir); let state_id = 1234u128; let state_hash = file_io @@ -281,10 +241,10 @@ pub fn test_file_io_get_state_hash_works() { pub fn 
test_state_files_from_handler_can_be_loaded_again() { let shard: ShardIdentifier = [15u8; 32].into(); - let state_key_access = - Arc::new(temp_state_key_repository("test_state_files_from_handler_can_be_loaded_again")); - let (state_handler, _shard_dir_handle) = - initialize_state_handler_with_directory_handle(&shard, state_key_access.clone()); + let (_temp_dir, state_key_access, state_dir) = + test_setup("test_state_files_from_handler_can_be_loaded_again", &shard); + + let state_handler = initialize_state_handler(state_key_access.clone(), state_dir.clone()); update_state(state_handler.as_ref(), &shard, ("test_key_1".encode(), "value1".encode())); update_state(state_handler.as_ref(), &shard, ("test_key_2".encode(), "value2".encode())); @@ -296,9 +256,12 @@ pub fn test_state_files_from_handler_can_be_loaded_again() { update_state(state_handler.as_ref(), &shard, ("test_key_3".encode(), "value3".encode())); // We initialize another state handler to load the state from the changes we just made. 
- let updated_state_handler = initialize_state_handler(state_key_access); + let updated_state_handler = initialize_state_handler(state_key_access, state_dir.clone()); - assert_eq!(STATE_SNAPSHOTS_CACHE_SIZE, number_of_files_in_shard_dir(&shard).unwrap()); + assert_eq!( + STATE_SNAPSHOTS_CACHE_SIZE, + state_dir.list_state_ids_for_shard(&shard).unwrap().len() + ); assert_eq!( &"value3".encode(), updated_state_handler @@ -313,15 +276,12 @@ pub fn test_state_files_from_handler_can_be_loaded_again() { pub fn test_list_state_ids_ignores_files_not_matching_the_pattern() { let shard: ShardIdentifier = [21u8; 32].into(); - let _shard_dir_handle = ShardDirectoryHandle::new(shard).unwrap(); - let state_key_access = Arc::new(temp_state_key_repository( - "test_list_state_ids_ignores_files_not_matching_the_pattern", - )); + let (_temp_dir, state_key_access, state_dir) = + test_setup("test_list_state_ids_ignores_files_not_matching_the_pattern", &shard); - let file_io = TestStateFileIo::new(state_key_access); + let file_io = TestStateFileIo::new(state_key_access, state_dir.clone()); - let mut invalid_state_file_path = shard_path(&shard); - invalid_state_file_path.push("invalid-state.bin"); + let invalid_state_file_path = state_dir.shard_path(&shard).join("invalid-state.bin"); write(&[0, 1, 2, 3, 4, 5], invalid_state_file_path).unwrap(); file_io @@ -333,9 +293,11 @@ pub fn test_list_state_ids_ignores_files_not_matching_the_pattern() { pub fn test_in_memory_state_initializes_from_shard_directory() { let shard: ShardIdentifier = [45u8; 32].into(); - let _shard_dir_handle = ShardDirectoryHandle::new(shard).unwrap(); + let (_temp_dir, _, state_dir) = + test_setup("test_list_state_ids_ignores_files_not_matching_the_pattern", &shard); - let file_io = create_in_memory_state_io_from_shards_directories().unwrap(); + let file_io = + create_in_memory_state_io_from_shards_directories(&state_dir.shards_directory()).unwrap(); let state_initializer = 
Arc::new(TestStateInitializer::new(StfState::new(Default::default()))); let state_repository_loader = StateSnapshotRepositoryLoader::new(file_io.clone(), state_initializer); @@ -347,16 +309,11 @@ pub fn test_in_memory_state_initializes_from_shard_directory() { assert!(state_snapshot_repository.shard_exists(&shard)); } -fn initialize_state_handler_with_directory_handle( - shard: &ShardIdentifier, +fn initialize_state_handler( state_key_access: Arc, -) -> (Arc, ShardDirectoryHandle) { - let shard_dir_handle = ShardDirectoryHandle::new(*shard).unwrap(); - (initialize_state_handler(state_key_access), shard_dir_handle) -} - -fn initialize_state_handler(state_key_access: Arc) -> Arc { - let file_io = Arc::new(TestStateFileIo::new(state_key_access)); + state_dir: StateDir, +) -> Arc { + let file_io = Arc::new(TestStateFileIo::new(state_key_access, state_dir)); let state_initializer = Arc::new(TestStateInitializer::new(StfState::new(Default::default()))); let state_repository_loader = TestStateRepositoryLoader::new(file_io, state_initializer.clone()); @@ -392,19 +349,11 @@ fn given_hello_world_state() -> StfState { state } -fn given_initialized_shard(shard: &ShardIdentifier) -> Result<()> { - if shard_exists(&shard) { - purge_shard_dir(shard); - } - init_shard(&shard) -} - -fn number_of_files_in_shard_dir(shard: &ShardIdentifier) -> Result { - let shard_dir_path = shard_path(shard); - let files_in_dir = - std::fs::read_dir(shard_dir_path.clone()).map_err(|e| Error::Other(e.into()))?; - - log::info!("File in shard dir: {:?}", files_in_dir); +fn test_setup(id: &str, shard: &ShardIdentifier) -> (TempDir, Arc, StateDir) { + let temp_dir = TempDir::with_prefix(id).unwrap(); + let state_key_access = Arc::new(get_aes_repository(temp_dir.path().to_path_buf()).unwrap()); + let state_dir = StateDir::new(temp_dir.path().to_path_buf()); + state_dir.given_initialized_shard(shard); - Ok(files_in_dir.count()) + (temp_dir, state_key_access, state_dir) } diff --git 
a/enclave-runtime/src/initialization/mod.rs b/enclave-runtime/src/initialization/mod.rs index d7d971fa4d..98fb9bc8f2 100644 --- a/enclave-runtime/src/initialization/mod.rs +++ b/enclave-runtime/src/initialization/mod.rs @@ -64,7 +64,7 @@ use itp_sgx_crypto::{ get_aes_repository, get_ed25519_repository, get_rsa3072_repository, key_repository::AccessKey, }; use itp_stf_state_handler::{ - handle_state::HandleState, query_shard_state::QueryShardState, + file_io::StateDir, handle_state::HandleState, query_shard_state::QueryShardState, state_snapshot_repository::VersionedStateAccess, state_snapshot_repository_loader::StateSnapshotRepositoryLoader, StateHandler, }; @@ -91,10 +91,11 @@ pub(crate) fn init_enclave( // Create the aes key that is used for state encryption such that a key is always present in tests. // It will be overwritten anyway if mutual remote attestation is performed with the primary worker. - let state_key_repository = Arc::new(get_aes_repository(base_dir)?); + let state_key_repository = Arc::new(get_aes_repository(base_dir.clone())?); GLOBAL_STATE_KEY_REPOSITORY_COMPONENT.initialize(state_key_repository.clone()); - let state_file_io = Arc::new(EnclaveStateFileIo::new(state_key_repository)); + let state_file_io = + Arc::new(EnclaveStateFileIo::new(state_key_repository, StateDir::new(base_dir))); let state_initializer = Arc::new(EnclaveStateInitializer::new(shielding_key_repository.clone())); let state_snapshot_repository_loader = StateSnapshotRepositoryLoader::<