diff --git a/CHANGELOG.md b/CHANGELOG.md index c28ca2be5..649fe3f1c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- Introduce `DatabaseFactory` trait. ## [v0.20.0] - [v0.19.0] diff --git a/src/database/any.rs b/src/database/any.rs index bbd9d41a5..06cc66925 100644 --- a/src/database/any.rs +++ b/src/database/any.rs @@ -425,3 +425,42 @@ impl ConfigurableDatabase for AnyDatabase { impl_from!((), AnyDatabaseConfig, Memory,); impl_from!(SledDbConfiguration, AnyDatabaseConfig, Sled, #[cfg(feature = "key-value-db")]); impl_from!(SqliteDbConfiguration, AnyDatabaseConfig, Sqlite, #[cfg(feature = "sqlite")]); + +/// Type that implements [`DatabaseFactory`] that builds [`AnyDatabase`]. +pub enum AnyDatabaseFactory { + /// Memory database factory + Memory(memory::MemoryDatabaseFactory), + #[cfg(feature = "key-value-db")] + #[cfg_attr(docsrs, doc(cfg(feature = "key-value-db")))] + /// Key-value database factory + Sled(sled::Db), + #[cfg(feature = "sqlite")] + #[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))] + /// Sqlite database factory + Sqlite(sqlite::SqliteDatabaseFactory), +} + +impl DatabaseFactory for AnyDatabaseFactory { + type Inner = AnyDatabase; + + fn build( + &self, + descriptor: &ExtendedDescriptor, + network: Network, + secp: &SecpCtx, + ) -> Result { + match self { + AnyDatabaseFactory::Memory(f) => { + f.build(descriptor, network, secp).map(Self::Inner::Memory) + } + #[cfg(feature = "key-value-db")] + AnyDatabaseFactory::Sled(f) => { + f.build(descriptor, network, secp).map(Self::Inner::Sled) + } + #[cfg(feature = "sqlite")] + AnyDatabaseFactory::Sqlite(f) => { + f.build(descriptor, network, secp).map(Self::Inner::Sqlite) + } + } + } +} diff --git a/src/database/keyvalue.rs b/src/database/keyvalue.rs index e10ada1fa..e9162f00d 100644 --- a/src/database/keyvalue.rs +++ b/src/database/keyvalue.rs @@ -21,6 +21,9 @@ use crate::database::memory::MapKey; use crate::database::{BatchDatabase, BatchOperations, Database, SyncTime}; use crate::error::Error; use crate::types::*; +use crate::wallet::wallet_name_from_descriptor; + +use super::DatabaseFactory; macro_rules! impl_batch_operations { ( { $($after_insert:tt)* }, $process_delete:ident ) => { @@ -402,6 +405,21 @@ impl BatchDatabase for Tree { } } +/// A [`DatabaseFactory`] implementation that builds [`Tree`] +impl DatabaseFactory for sled::Db { + type Inner = sled::Tree; + + fn build( + &self, + descriptor: &crate::descriptor::ExtendedDescriptor, + network: bitcoin::Network, + secp: &crate::wallet::utils::SecpCtx, + ) -> Result { + let name = wallet_name_from_descriptor(descriptor.clone(), None, network, secp)?; + self.open_tree(&name).map_err(Error::Sled) + } +} + #[cfg(test)] mod test { use lazy_static::lazy_static; @@ -492,4 +510,16 @@ mod test { fn test_sync_time() { crate::database::test::test_sync_time(get_tree()); } + + #[test] + fn test_factory() { + let time = SystemTime::now().duration_since(UNIX_EPOCH).unwrap(); + let mut dir = std::env::temp_dir(); + dir.push(format!("bdk_{}", time.as_nanos())); + + let fac = sled::open(&dir).unwrap(); + crate::database::test::test_factory(&fac); + + std::fs::remove_dir_all(&dir).unwrap(); + } } diff --git a/src/database/memory.rs b/src/database/memory.rs index 7d806eb4a..ccc8d8175 100644 --- a/src/database/memory.rs +++ b/src/database/memory.rs @@ -26,6 +26,8 @@ use crate::database::{BatchDatabase, BatchOperations, ConfigurableDatabase, Data use crate::error::Error; use crate::types::*; +use super::DatabaseFactory; + // path -> script p{i,e} -> script // script -> path s