diff --git a/crates/chain/src/lib.rs b/crates/chain/src/lib.rs index a756ab11c..3fb8c0eda 100644 --- a/crates/chain/src/lib.rs +++ b/crates/chain/src/lib.rs @@ -37,8 +37,6 @@ pub use tx_data_traits::*; pub use tx_graph::TxGraph; mod chain_oracle; pub use chain_oracle::*; -mod persist; -pub use persist::*; #[doc(hidden)] pub mod example_utils; diff --git a/crates/chain/src/persist.rs b/crates/chain/src/persist.rs deleted file mode 100644 index 2ec88f636..000000000 --- a/crates/chain/src/persist.rs +++ /dev/null @@ -1,169 +0,0 @@ -use core::{ - future::Future, - ops::{Deref, DerefMut}, - pin::Pin, -}; - -use alloc::boxed::Box; - -use crate::Merge; - -/// Represents a type that contains staged changes. -pub trait Staged { - /// Type for staged changes. - type ChangeSet: Merge; - - /// Get mutable reference of staged changes. - fn staged(&mut self) -> &mut Self::ChangeSet; -} - -/// Trait that persists the type with `Db`. -/// -/// Methods of this trait should not be called directly. -pub trait PersistWith: Staged + Sized { - /// Parameters for [`PersistWith::create`]. - type CreateParams; - /// Parameters for [`PersistWith::load`]. - type LoadParams; - /// Error type of [`PersistWith::create`]. - type CreateError; - /// Error type of [`PersistWith::load`]. - type LoadError; - /// Error type of [`PersistWith::persist`]. - type PersistError; - - /// Initialize the `Db` and create `Self`. - fn create(db: &mut Db, params: Self::CreateParams) -> Result; - - /// Initialize the `Db` and load a previously-persisted `Self`. - fn load(db: &mut Db, params: Self::LoadParams) -> Result, Self::LoadError>; - - /// Persist changes to the `Db`. - fn persist( - db: &mut Db, - changeset: &::ChangeSet, - ) -> Result<(), Self::PersistError>; -} - -type FutureResult<'a, T, E> = Pin> + Send + 'a>>; - -/// Trait that persists the type with an async `Db`. -pub trait PersistAsyncWith: Staged + Sized { - /// Parameters for [`PersistAsyncWith::create`]. - type CreateParams; - /// Parameters for [`PersistAsyncWith::load`]. - type LoadParams; - /// Error type of [`PersistAsyncWith::create`]. - type CreateError; - /// Error type of [`PersistAsyncWith::load`]. - type LoadError; - /// Error type of [`PersistAsyncWith::persist`]. - type PersistError; - - /// Initialize the `Db` and create `Self`. - fn create(db: &mut Db, params: Self::CreateParams) -> FutureResult; - - /// Initialize the `Db` and load a previously-persisted `Self`. - fn load(db: &mut Db, params: Self::LoadParams) -> FutureResult, Self::LoadError>; - - /// Persist changes to the `Db`. - fn persist<'a>( - db: &'a mut Db, - changeset: &'a ::ChangeSet, - ) -> FutureResult<'a, (), Self::PersistError>; -} - -/// Represents a persisted `T`. -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct Persisted { - inner: T, -} - -impl Persisted { - /// Create a new persisted `T`. - pub fn create(db: &mut Db, params: T::CreateParams) -> Result - where - T: PersistWith, - { - T::create(db, params).map(|inner| Self { inner }) - } - - /// Create a new persisted `T` with async `Db`. - pub async fn create_async( - db: &mut Db, - params: T::CreateParams, - ) -> Result - where - T: PersistAsyncWith, - { - T::create(db, params).await.map(|inner| Self { inner }) - } - - /// Construct a persisted `T` from `Db`. - pub fn load(db: &mut Db, params: T::LoadParams) -> Result, T::LoadError> - where - T: PersistWith, - { - Ok(T::load(db, params)?.map(|inner| Self { inner })) - } - - /// Construct a persisted `T` from an async `Db`. - pub async fn load_async( - db: &mut Db, - params: T::LoadParams, - ) -> Result, T::LoadError> - where - T: PersistAsyncWith, - { - Ok(T::load(db, params).await?.map(|inner| Self { inner })) - } - - /// Persist staged changes of `T` into `Db`. - /// - /// If the database errors, the staged changes will not be cleared. - pub fn persist(&mut self, db: &mut Db) -> Result - where - T: PersistWith, - { - let stage = T::staged(&mut self.inner); - if stage.is_empty() { - return Ok(false); - } - T::persist(db, &*stage)?; - stage.take(); - Ok(true) - } - - /// Persist staged changes of `T` into an async `Db`. - /// - /// If the database errors, the staged changes will not be cleared. - pub async fn persist_async<'a, Db>( - &'a mut self, - db: &'a mut Db, - ) -> Result - where - T: PersistAsyncWith, - { - let stage = T::staged(&mut self.inner); - if stage.is_empty() { - return Ok(false); - } - T::persist(db, &*stage).await?; - stage.take(); - Ok(true) - } -} - -impl Deref for Persisted { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl DerefMut for Persisted { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.inner - } -} diff --git a/crates/chain/src/rusqlite_impl.rs b/crates/chain/src/rusqlite_impl.rs index a52c491c6..d8ef65c42 100644 --- a/crates/chain/src/rusqlite_impl.rs +++ b/crates/chain/src/rusqlite_impl.rs @@ -225,7 +225,7 @@ where pub const ANCHORS_TABLE_NAME: &'static str = "bdk_anchors"; /// Initialize sqlite tables. - fn init_sqlite_tables(db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { + pub fn init_sqlite_tables(db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { let schema_v0: &[&str] = &[ // full transactions &format!( @@ -264,9 +264,9 @@ where } /// Construct a [`TxGraph`] from an sqlite database. + /// + /// Remember to call [`Self::init_sqlite_tables`] beforehand. pub fn from_sqlite(db_tx: &rusqlite::Transaction) -> rusqlite::Result { - Self::init_sqlite_tables(db_tx)?; - let mut changeset = Self::default(); let mut statement = db_tx.prepare(&format!( @@ -332,9 +332,9 @@ where } /// Persist `changeset` to the sqlite database. + /// + /// Remember to call [`Self::init_sqlite_tables`] beforehand. pub fn persist_to_sqlite(&self, db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { - Self::init_sqlite_tables(db_tx)?; - let mut statement = db_tx.prepare_cached(&format!( "INSERT INTO {}(txid, raw_tx) VALUES(:txid, :raw_tx) ON CONFLICT(txid) DO UPDATE SET raw_tx=:raw_tx", Self::TXS_TABLE_NAME, @@ -396,7 +396,7 @@ impl local_chain::ChangeSet { pub const BLOCKS_TABLE_NAME: &'static str = "bdk_blocks"; /// Initialize sqlite tables for persisting [`local_chain::LocalChain`]. - fn init_sqlite_tables(db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { + pub fn init_sqlite_tables(db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { let schema_v0: &[&str] = &[ // blocks &format!( @@ -411,9 +411,9 @@ impl local_chain::ChangeSet { } /// Construct a [`LocalChain`](local_chain::LocalChain) from sqlite database. + /// + /// Remember to call [`Self::init_sqlite_tables`] beforehand. pub fn from_sqlite(db_tx: &rusqlite::Transaction) -> rusqlite::Result { - Self::init_sqlite_tables(db_tx)?; - let mut changeset = Self::default(); let mut statement = db_tx.prepare(&format!( @@ -435,9 +435,9 @@ impl local_chain::ChangeSet { } /// Persist `changeset` to the sqlite database. + /// + /// Remember to call [`Self::init_sqlite_tables`] beforehand. pub fn persist_to_sqlite(&self, db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { - Self::init_sqlite_tables(db_tx)?; - let mut replace_statement = db_tx.prepare_cached(&format!( "REPLACE INTO {}(block_height, block_hash) VALUES(:block_height, :block_hash)", Self::BLOCKS_TABLE_NAME, @@ -471,7 +471,7 @@ impl keychain_txout::ChangeSet { /// Initialize sqlite tables for persisting /// [`KeychainTxOutIndex`](keychain_txout::KeychainTxOutIndex). - fn init_sqlite_tables(db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { + pub fn init_sqlite_tables(db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { let schema_v0: &[&str] = &[ // last revealed &format!( @@ -487,9 +487,9 @@ impl keychain_txout::ChangeSet { /// Construct [`KeychainTxOutIndex`](keychain_txout::KeychainTxOutIndex) from sqlite database /// and given parameters. + /// + /// Remember to call [`Self::init_sqlite_tables`] beforehand. pub fn from_sqlite(db_tx: &rusqlite::Transaction) -> rusqlite::Result { - Self::init_sqlite_tables(db_tx)?; - let mut changeset = Self::default(); let mut statement = db_tx.prepare(&format!( @@ -511,9 +511,9 @@ impl keychain_txout::ChangeSet { } /// Persist `changeset` to the sqlite database. + /// + /// Remember to call [`Self::init_sqlite_tables`] beforehand. pub fn persist_to_sqlite(&self, db_tx: &rusqlite::Transaction) -> rusqlite::Result<()> { - Self::init_sqlite_tables(db_tx)?; - let mut statement = db_tx.prepare_cached(&format!( "REPLACE INTO {}(descriptor_id, last_revealed) VALUES(:descriptor_id, :last_revealed)", Self::LAST_REVEALED_TABLE_NAME, diff --git a/crates/wallet/src/wallet/changeset.rs b/crates/wallet/src/wallet/changeset.rs index 5f3b9b3dc..2d4b700ed 100644 --- a/crates/wallet/src/wallet/changeset.rs +++ b/crates/wallet/src/wallet/changeset.rs @@ -72,10 +72,8 @@ impl ChangeSet { /// Name of table to store wallet descriptors and network. pub const WALLET_TABLE_NAME: &'static str = "bdk_wallet"; - /// Initialize sqlite tables for wallet schema & table. - fn init_wallet_sqlite_tables( - db_tx: &chain::rusqlite::Transaction, - ) -> chain::rusqlite::Result<()> { + /// Initialize sqlite tables for wallet tables. + pub fn init_sqlite_tables(db_tx: &chain::rusqlite::Transaction) -> chain::rusqlite::Result<()> { let schema_v0: &[&str] = &[&format!( "CREATE TABLE {} ( \ id INTEGER PRIMARY KEY NOT NULL CHECK (id = 0), \ @@ -85,12 +83,17 @@ impl ChangeSet { ) STRICT;", Self::WALLET_TABLE_NAME, )]; - crate::rusqlite_impl::migrate_schema(db_tx, Self::WALLET_SCHEMA_NAME, &[schema_v0]) + crate::rusqlite_impl::migrate_schema(db_tx, Self::WALLET_SCHEMA_NAME, &[schema_v0])?; + + bdk_chain::local_chain::ChangeSet::init_sqlite_tables(db_tx)?; + bdk_chain::tx_graph::ChangeSet::::init_sqlite_tables(db_tx)?; + bdk_chain::keychain_txout::ChangeSet::init_sqlite_tables(db_tx)?; + + Ok(()) } /// Recover a [`ChangeSet`] from sqlite database. pub fn from_sqlite(db_tx: &chain::rusqlite::Transaction) -> chain::rusqlite::Result { - Self::init_wallet_sqlite_tables(db_tx)?; use chain::rusqlite::OptionalExtension; use chain::Impl; @@ -129,7 +132,6 @@ impl ChangeSet { &self, db_tx: &chain::rusqlite::Transaction, ) -> chain::rusqlite::Result<()> { - Self::init_wallet_sqlite_tables(db_tx)?; use chain::rusqlite::named_params; use chain::Impl; diff --git a/crates/wallet/src/wallet/mod.rs b/crates/wallet/src/wallet/mod.rs index f98b16e91..4cd721ba1 100644 --- a/crates/wallet/src/wallet/mod.rs +++ b/crates/wallet/src/wallet/mod.rs @@ -45,7 +45,6 @@ use bitcoin::{ use bitcoin::{consensus::encode::serialize, transaction, BlockHash, Psbt}; use bitcoin::{constants::genesis_block, Amount}; use bitcoin::{secp256k1::Secp256k1, Weight}; -use chain::Staged; use core::fmt; use core::mem; use core::ops::Deref; @@ -123,14 +122,6 @@ pub struct Wallet { secp: SecpCtx, } -impl Staged for Wallet { - type ChangeSet = ChangeSet; - - fn staged(&mut self) -> &mut Self::ChangeSet { - &mut self.stage - } -} - /// An update to [`Wallet`]. /// /// It updates [`KeychainTxOutIndex`], [`bdk_chain::TxGraph`] and [`local_chain::LocalChain`] atomically. @@ -2303,7 +2294,7 @@ impl Wallet { Ok(()) } - /// Get a reference of the staged [`ChangeSet`] that are yet to be committed (if any). + /// Get a reference of the staged [`ChangeSet`] that is yet to be committed (if any). pub fn staged(&self) -> Option<&ChangeSet> { if self.stage.is_empty() { None @@ -2312,6 +2303,15 @@ impl Wallet { } } + /// Get a mutable reference of the staged [`ChangeSet`] that is yet to be commited (if any). + pub fn staged_mut(&mut self) -> Option<&mut ChangeSet> { + if self.stage.is_empty() { + None + } else { + Some(&mut self.stage) + } + } + /// Take the staged [`ChangeSet`] to be persisted now (if any). pub fn take_staged(&mut self) -> Option { self.stage.take() diff --git a/crates/wallet/src/wallet/params.rs b/crates/wallet/src/wallet/params.rs index 9b0795395..f91034002 100644 --- a/crates/wallet/src/wallet/params.rs +++ b/crates/wallet/src/wallet/params.rs @@ -1,12 +1,13 @@ use alloc::boxed::Box; -use bdk_chain::{keychain_txout::DEFAULT_LOOKAHEAD, PersistAsyncWith, PersistWith}; +use bdk_chain::keychain_txout::DEFAULT_LOOKAHEAD; use bitcoin::{BlockHash, Network}; use miniscript::descriptor::KeyMap; use crate::{ descriptor::{DescriptorError, ExtendedDescriptor, IntoWalletDescriptor}, utils::SecpCtx, - KeychainKind, Wallet, + AsyncWalletPersister, CreateWithPersistError, KeychainKind, LoadWithPersistError, Wallet, + WalletPersister, }; use super::{ChangeSet, LoadError, PersistedWallet}; @@ -108,26 +109,26 @@ impl CreateParams { self } - /// Create [`PersistedWallet`] with the given `Db`. - pub fn create_wallet( + /// Create [`PersistedWallet`] with the given [`WalletPersister`]. + pub fn create_wallet

( self, - db: &mut Db, - ) -> Result>::CreateError> + persister: &mut P, + ) -> Result, CreateWithPersistError> where - Wallet: PersistWith, + P: WalletPersister, { - PersistedWallet::create(db, self) + PersistedWallet::create(persister, self) } - /// Create [`PersistedWallet`] with the given async `Db`. - pub async fn create_wallet_async( + /// Create [`PersistedWallet`] with the given [`AsyncWalletPersister`]. + pub async fn create_wallet_async

( self, - db: &mut Db, - ) -> Result>::CreateError> + persister: &mut P, + ) -> Result, CreateWithPersistError> where - Wallet: PersistAsyncWith, + P: AsyncWalletPersister, { - PersistedWallet::create_async(db, self).await + PersistedWallet::create_async(persister, self).await } /// Create [`Wallet`] without persistence. @@ -219,26 +220,26 @@ impl LoadParams { self } - /// Load [`PersistedWallet`] with the given `Db`. - pub fn load_wallet( + /// Load [`PersistedWallet`] with the given [`WalletPersister`]. + pub fn load_wallet

( self, - db: &mut Db, - ) -> Result, >::LoadError> + persister: &mut P, + ) -> Result>, LoadWithPersistError> where - Wallet: PersistWith, + P: WalletPersister, { - PersistedWallet::load(db, self) + PersistedWallet::load(persister, self) } - /// Load [`PersistedWallet`] with the given async `Db`. - pub async fn load_wallet_async( + /// Load [`PersistedWallet`] with the given [`AsyncWalletPersister`]. + pub async fn load_wallet_async

( self, - db: &mut Db, - ) -> Result, >::LoadError> + persister: &mut P, + ) -> Result>, LoadWithPersistError> where - Wallet: PersistAsyncWith, + P: AsyncWalletPersister, { - PersistedWallet::load_async(db, self).await + PersistedWallet::load_async(persister, self).await } /// Load [`Wallet`] without persistence. diff --git a/crates/wallet/src/wallet/persisted.rs b/crates/wallet/src/wallet/persisted.rs index cc9f267f4..a8876e8e4 100644 --- a/crates/wallet/src/wallet/persisted.rs +++ b/crates/wallet/src/wallet/persisted.rs @@ -1,130 +1,330 @@ -use core::fmt; +use core::{ + fmt, + future::Future, + marker::PhantomData, + ops::{Deref, DerefMut}, + pin::Pin, +}; -use crate::{descriptor::DescriptorError, Wallet}; +use alloc::boxed::Box; +use chain::Merge; -/// Represents a persisted wallet. -pub type PersistedWallet = bdk_chain::Persisted; +use crate::{descriptor::DescriptorError, ChangeSet, CreateParams, LoadParams, Wallet}; -#[cfg(feature = "rusqlite")] -impl<'c> chain::PersistWith> for Wallet { - type CreateParams = crate::CreateParams; - type LoadParams = crate::LoadParams; - - type CreateError = CreateWithPersistError; - type LoadError = LoadWithPersistError; - type PersistError = bdk_chain::rusqlite::Error; - - fn create( - db: &mut bdk_chain::rusqlite::Transaction<'c>, - params: Self::CreateParams, - ) -> Result { - let mut wallet = - Self::create_with_params(params).map_err(CreateWithPersistError::Descriptor)?; - if let Some(changeset) = wallet.take_staged() { - changeset - .persist_to_sqlite(db) +/// Trait that persists [`PersistedWallet`]. +/// +/// For an async version, use [`AsyncWalletPersister`]. +/// +/// Associated functions of this trait should not be called directly, and the trait is designed so +/// that associated functions are hard to find (since they are not methods!). [`WalletPersister`] is +/// used by [`PersistedWallet`] (a light wrapper around [`Wallet`]) which enforces some level of +/// safety. Refer to [`PersistedWallet`] for more about the safety checks. +pub trait WalletPersister { + /// Error type of the persister. + type Error; + + /// Initialize the `persister` and load all data. + /// + /// This is called by [`PersistedWallet::create`] and [`PersistedWallet::load`] to ensure + /// the [`WalletPersister`] is initialized and returns all data in the `persister`. + /// + /// # Implementation Details + /// + /// The database schema of the `persister` (if any), should be initialized and migrated here. + /// + /// The implementation must return all data currently stored in the `persister`. If there is no + /// data, return an empty changeset (using [`ChangeSet::default()`]). + /// + /// Error should only occur on database failure. Multiple calls to `initialize` should not + /// error. Calling `initialize` inbetween calls to `persist` should not error. + /// + /// Calling [`persist`] before the `persister` is `initialize`d may error. However, some + /// persister implementations may NOT require initialization at all (and not error). + /// + /// [`persist`]: WalletPersister::persist + fn initialize(persister: &mut Self) -> Result; + + /// Persist the given `changeset` to the `persister`. + /// + /// This method can fail if the `persister` is not [`initialize`]d. + /// + /// [`initialize`]: WalletPersister::initialize + fn persist(persister: &mut Self, changeset: &ChangeSet) -> Result<(), Self::Error>; +} + +type FutureResult<'a, T, E> = Pin> + Send + 'a>>; + +/// Async trait that persists [`PersistedWallet`]. +/// +/// For a blocking version, use [`WalletPersister`]. +/// +/// Associated functions of this trait should not be called directly, and the trait is designed so +/// that associated functions are hard to find (since they are not methods!). [`AsyncWalletPersister`] is +/// used by [`PersistedWallet`] (a light wrapper around [`Wallet`]) which enforces some level of +/// safety. Refer to [`PersistedWallet`] for more about the safety checks. +pub trait AsyncWalletPersister { + /// Error type of the persister. + type Error; + + /// Initialize the `persister` and load all data. + /// + /// This is called by [`PersistedWallet::create_async`] and [`PersistedWallet::load_async`] to + /// ensure the [`AsyncWalletPersister`] is initialized and returns all data in the `persister`. + /// + /// # Implementation Details + /// + /// The database schema of the `persister` (if any), should be initialized and migrated here. + /// + /// The implementation must return all data currently stored in the `persister`. If there is no + /// data, return an empty changeset (using [`ChangeSet::default()`]). + /// + /// Error should only occur on database failure. Multiple calls to `initialize` should not + /// error. Calling `initialize` inbetween calls to `persist` should not error. + /// + /// Calling [`persist`] before the `persister` is `initialize`d may error. However, some + /// persister implementations may NOT require initialization at all (and not error). + /// + /// [`persist`]: AsyncWalletPersister::persist + fn initialize<'a>(persister: &'a mut Self) -> FutureResult<'a, ChangeSet, Self::Error> + where + Self: 'a; + + /// Persist the given `changeset` to the `persister`. + /// + /// This method can fail if the `persister` is not [`initialize`]d. + /// + /// [`initialize`]: AsyncWalletPersister::initialize + fn persist<'a>( + persister: &'a mut Self, + changeset: &'a ChangeSet, + ) -> FutureResult<'a, (), Self::Error> + where + Self: 'a; +} + +/// Represents a persisted wallet which persists into type `P`. +/// +/// This is a light wrapper around [`Wallet`] that enforces some level of safety-checking when used +/// with a [`WalletPersister`] or [`AsyncWalletPersister`] implementation. Safety checks assume that +/// [`WalletPersister`] and/or [`AsyncWalletPersister`] are implemented correctly. +/// +/// Checks include: +/// +/// * Ensure the persister is initialized before data is persisted. +/// * Ensure there were no previously persisted wallet data before creating a fresh wallet and +/// persisting it. +/// * Only clear the staged changes of [`Wallet`] after persisting succeeds. +/// * Ensure the wallet is persisted to the same `P` type as when created/loaded. Note that this is +/// not completely fool-proof as you can have multiple instances of the same `P` type that are +/// connected to different databases. +#[derive(Debug)] +pub struct PersistedWallet

{ + inner: Wallet, + marker: PhantomData

, +} + +impl

Deref for PersistedWallet

{ + type Target = Wallet; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl

DerefMut for PersistedWallet

{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +/// Methods when `P` is a [`WalletPersister`]. +impl PersistedWallet

{ + /// Create a new [`PersistedWallet`] with the given `persister` and `params`. + pub fn create( + persister: &mut P, + params: CreateParams, + ) -> Result> { + let existing = P::initialize(persister).map_err(CreateWithPersistError::Persist)?; + if !existing.is_empty() { + return Err(CreateWithPersistError::DataAlreadyExists(existing)); + } + let mut inner = + Wallet::create_with_params(params).map_err(CreateWithPersistError::Descriptor)?; + if let Some(changeset) = inner.take_staged() { + P::persist(persister, &changeset).map_err(CreateWithPersistError::Persist)?; + } + Ok(Self { + inner, + marker: PhantomData, + }) + } + + /// Load a previously [`PersistedWallet`] from the given `persister` and `params`. + pub fn load( + persister: &mut P, + params: LoadParams, + ) -> Result, LoadWithPersistError> { + let changeset = P::initialize(persister).map_err(LoadWithPersistError::Persist)?; + Wallet::load_with_params(changeset, params) + .map(|opt| { + opt.map(|inner| PersistedWallet { + inner, + marker: PhantomData, + }) + }) + .map_err(LoadWithPersistError::InvalidChangeSet) + } + + /// Persist staged changes of wallet into `persister`. + /// + /// Returns whether any new changes were persisted. + /// + /// If the `persister` errors, the staged changes will not be cleared. + pub fn persist(&mut self, persister: &mut P) -> Result { + match self.inner.staged_mut() { + Some(stage) => { + P::persist(persister, &*stage)?; + let _ = stage.take(); + Ok(true) + } + None => Ok(false), + } + } +} + +/// Methods when `P` is an [`AsyncWalletPersister`]. +impl PersistedWallet

{ + /// Create a new [`PersistedWallet`] with the given async `persister` and `params`. + pub async fn create_async( + persister: &mut P, + params: CreateParams, + ) -> Result> { + let existing = P::initialize(persister) + .await + .map_err(CreateWithPersistError::Persist)?; + if !existing.is_empty() { + return Err(CreateWithPersistError::DataAlreadyExists(existing)); + } + let mut inner = + Wallet::create_with_params(params).map_err(CreateWithPersistError::Descriptor)?; + if let Some(changeset) = inner.take_staged() { + P::persist(persister, &changeset) + .await .map_err(CreateWithPersistError::Persist)?; } - Ok(wallet) + Ok(Self { + inner, + marker: PhantomData, + }) } - fn load( - conn: &mut bdk_chain::rusqlite::Transaction<'c>, - params: Self::LoadParams, - ) -> Result, Self::LoadError> { - let changeset = - crate::ChangeSet::from_sqlite(conn).map_err(LoadWithPersistError::Persist)?; - if chain::Merge::is_empty(&changeset) { - return Ok(None); + /// Load a previously [`PersistedWallet`] from the given async `persister` and `params`. + pub async fn load_async( + persister: &mut P, + params: LoadParams, + ) -> Result, LoadWithPersistError> { + let changeset = P::initialize(persister) + .await + .map_err(LoadWithPersistError::Persist)?; + Wallet::load_with_params(changeset, params) + .map(|opt| { + opt.map(|inner| PersistedWallet { + inner, + marker: PhantomData, + }) + }) + .map_err(LoadWithPersistError::InvalidChangeSet) + } + + /// Persist staged changes of wallet into an async `persister`. + /// + /// Returns whether any new changes were persisted. + /// + /// If the `persister` errors, the staged changes will not be cleared. + pub async fn persist_async<'a>(&'a mut self, persister: &mut P) -> Result { + match self.inner.staged_mut() { + Some(stage) => { + P::persist(persister, &*stage).await?; + let _ = stage.take(); + Ok(true) + } + None => Ok(false), } - Self::load_with_params(changeset, params).map_err(LoadWithPersistError::InvalidChangeSet) } +} - fn persist( - db: &mut bdk_chain::rusqlite::Transaction<'c>, - changeset: &::ChangeSet, - ) -> Result<(), Self::PersistError> { - changeset.persist_to_sqlite(db) +#[cfg(feature = "rusqlite")] +impl<'c> WalletPersister for bdk_chain::rusqlite::Transaction<'c> { + type Error = bdk_chain::rusqlite::Error; + + fn initialize(persister: &mut Self) -> Result { + ChangeSet::init_sqlite_tables(&*persister)?; + ChangeSet::from_sqlite(persister) + } + + fn persist(persister: &mut Self, changeset: &ChangeSet) -> Result<(), Self::Error> { + changeset.persist_to_sqlite(persister) } } #[cfg(feature = "rusqlite")] -impl chain::PersistWith for Wallet { - type CreateParams = crate::CreateParams; - type LoadParams = crate::LoadParams; - - type CreateError = CreateWithPersistError; - type LoadError = LoadWithPersistError; - type PersistError = bdk_chain::rusqlite::Error; - - fn create( - db: &mut bdk_chain::rusqlite::Connection, - params: Self::CreateParams, - ) -> Result { - let mut db_tx = db.transaction().map_err(CreateWithPersistError::Persist)?; - let wallet = chain::PersistWith::create(&mut db_tx, params)?; - db_tx.commit().map_err(CreateWithPersistError::Persist)?; - Ok(wallet) - } - - fn load( - db: &mut bdk_chain::rusqlite::Connection, - params: Self::LoadParams, - ) -> Result, Self::LoadError> { - let mut db_tx = db.transaction().map_err(LoadWithPersistError::Persist)?; - let wallet_opt = chain::PersistWith::load(&mut db_tx, params)?; - db_tx.commit().map_err(LoadWithPersistError::Persist)?; - Ok(wallet_opt) - } - - fn persist( - db: &mut bdk_chain::rusqlite::Connection, - changeset: &::ChangeSet, - ) -> Result<(), Self::PersistError> { - let db_tx = db.transaction()?; +impl WalletPersister for bdk_chain::rusqlite::Connection { + type Error = bdk_chain::rusqlite::Error; + + fn initialize(persister: &mut Self) -> Result { + let db_tx = persister.transaction()?; + ChangeSet::init_sqlite_tables(&db_tx)?; + let changeset = ChangeSet::from_sqlite(&db_tx)?; + db_tx.commit()?; + Ok(changeset) + } + + fn persist(persister: &mut Self, changeset: &ChangeSet) -> Result<(), Self::Error> { + let db_tx = persister.transaction()?; changeset.persist_to_sqlite(&db_tx)?; db_tx.commit() } } +/// Error for [`bdk_file_store`]'s implementation of [`WalletPersister`]. #[cfg(feature = "file_store")] -impl chain::PersistWith> for Wallet { - type CreateParams = crate::CreateParams; - type LoadParams = crate::LoadParams; - type CreateError = CreateWithPersistError; - type LoadError = - LoadWithPersistError>; - type PersistError = std::io::Error; - - fn create( - db: &mut bdk_file_store::Store, - params: Self::CreateParams, - ) -> Result { - let mut wallet = - Self::create_with_params(params).map_err(CreateWithPersistError::Descriptor)?; - if let Some(changeset) = wallet.take_staged() { - db.append_changeset(&changeset) - .map_err(CreateWithPersistError::Persist)?; +#[derive(Debug)] +pub enum FileStoreError { + /// Error when loading from the store. + Load(bdk_file_store::AggregateChangesetsError), + /// Error when writing to the store. + Write(std::io::Error), +} + +#[cfg(feature = "file_store")] +impl core::fmt::Display for FileStoreError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + use core::fmt::Display; + match self { + FileStoreError::Load(e) => Display::fmt(e, f), + FileStoreError::Write(e) => Display::fmt(e, f), } - Ok(wallet) } +} + +#[cfg(feature = "file_store")] +impl std::error::Error for FileStoreError {} + +#[cfg(feature = "file_store")] +impl WalletPersister for bdk_file_store::Store { + type Error = FileStoreError; - fn load( - db: &mut bdk_file_store::Store, - params: Self::LoadParams, - ) -> Result, Self::LoadError> { - let changeset = db + fn initialize(persister: &mut Self) -> Result { + persister .aggregate_changesets() - .map_err(LoadWithPersistError::Persist)? - .unwrap_or_default(); - Self::load_with_params(changeset, params).map_err(LoadWithPersistError::InvalidChangeSet) + .map(Option::unwrap_or_default) + .map_err(FileStoreError::Load) } - fn persist( - db: &mut bdk_file_store::Store, - changeset: &::ChangeSet, - ) -> Result<(), Self::PersistError> { - db.append_changeset(changeset) + fn persist(persister: &mut Self, changeset: &ChangeSet) -> Result<(), Self::Error> { + persister + .append_changeset(changeset) + .map_err(FileStoreError::Write) } } @@ -154,6 +354,8 @@ impl std::error::Error for LoadWithPersistError pub enum CreateWithPersistError { /// Error from persistence. Persist(E), + /// Persister already has wallet data. + DataAlreadyExists(ChangeSet), /// Occurs when the loaded changeset cannot construct [`Wallet`]. Descriptor(DescriptorError), } @@ -162,6 +364,11 @@ impl fmt::Display for CreateWithPersistError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Persist(err) => fmt::Display::fmt(err, f), + Self::DataAlreadyExists(changeset) => write!( + f, + "Cannot create wallet in persister which already contains wallet data: {:?}", + changeset + ), Self::Descriptor(err) => fmt::Display::fmt(&err, f), } } diff --git a/crates/wallet/tests/wallet.rs b/crates/wallet/tests/wallet.rs index d41544a1d..c530e779c 100644 --- a/crates/wallet/tests/wallet.rs +++ b/crates/wallet/tests/wallet.rs @@ -5,15 +5,15 @@ use std::str::FromStr; use anyhow::Context; use assert_matches::assert_matches; +use bdk_chain::COINBASE_MATURITY; use bdk_chain::{BlockId, ConfirmationTime}; -use bdk_chain::{PersistWith, COINBASE_MATURITY}; use bdk_wallet::coin_selection::{self, LargestFirstCoinSelection}; use bdk_wallet::descriptor::{calc_checksum, DescriptorError, IntoWalletDescriptor}; use bdk_wallet::error::CreateTxError; use bdk_wallet::psbt::PsbtUtils; use bdk_wallet::signer::{SignOptions, SignerError}; use bdk_wallet::tx_builder::AddForeignUtxoError; -use bdk_wallet::{AddressInfo, Balance, CreateParams, LoadParams, Wallet}; +use bdk_wallet::{AddressInfo, Balance, ChangeSet, Wallet, WalletPersister}; use bdk_wallet::{KeychainKind, LoadError, LoadMismatch, LoadWithPersistError}; use bitcoin::constants::ChainHash; use bitcoin::hashes::Hash; @@ -111,10 +111,8 @@ fn wallet_is_persisted() -> anyhow::Result<()> { where CreateDb: Fn(&Path) -> anyhow::Result, OpenDb: Fn(&Path) -> anyhow::Result, - Wallet: PersistWith, - >::CreateError: std::error::Error + Send + Sync + 'static, - >::LoadError: std::error::Error + Send + Sync + 'static, - >::PersistError: std::error::Error + Send + Sync + 'static, + Db: WalletPersister, + Db::Error: std::error::Error + Send + Sync + 'static, { let temp_dir = tempfile::tempdir().expect("must create tempdir"); let file_path = temp_dir.path().join(filename); @@ -188,7 +186,7 @@ fn wallet_is_persisted() -> anyhow::Result<()> { #[test] fn wallet_load_checks() -> anyhow::Result<()> { - fn run( + fn run( filename: &str, create_db: CreateDb, open_db: OpenDb, @@ -196,15 +194,8 @@ fn wallet_load_checks() -> anyhow::Result<()> { where CreateDb: Fn(&Path) -> anyhow::Result, OpenDb: Fn(&Path) -> anyhow::Result, - Wallet: PersistWith< - Db, - CreateParams = CreateParams, - LoadParams = LoadParams, - LoadError = LoadWithPersistError, - >, - >::CreateError: std::error::Error + Send + Sync + 'static, - >::LoadError: std::error::Error + Send + Sync + 'static, - >::PersistError: std::error::Error + Send + Sync + 'static, + Db: WalletPersister + std::fmt::Debug, + Db::Error: std::error::Error + Send + Sync + 'static, { let temp_dir = tempfile::tempdir().expect("must create tempdir"); let file_path = temp_dir.path().join(filename); @@ -258,8 +249,12 @@ fn wallet_load_checks() -> anyhow::Result<()> { run( "store.db", - |path| Ok(bdk_file_store::Store::create_new(DB_MAGIC, path)?), - |path| Ok(bdk_file_store::Store::open(DB_MAGIC, path)?), + |path| { + Ok(bdk_file_store::Store::::create_new( + DB_MAGIC, path, + )?) + }, + |path| Ok(bdk_file_store::Store::::open(DB_MAGIC, path)?), )?; run( "store.sqlite", @@ -280,7 +275,7 @@ fn single_descriptor_wallet_persist_and_recover() { let mut db = rusqlite::Connection::open(db_path).unwrap(); let desc = get_test_tr_single_sig_xprv(); - let mut wallet = CreateParams::new_single(desc) + let mut wallet = Wallet::create_single(desc) .network(Network::Testnet) .create_wallet(&mut db) .unwrap(); @@ -4174,7 +4169,7 @@ fn test_insert_tx_balance_and_utxos() { #[test] fn single_descriptor_wallet_can_create_tx_and_receive_change() { // create single descriptor wallet and fund it - let mut wallet = CreateParams::new_single(get_test_tr_single_sig_xprv()) + let mut wallet = Wallet::create_single(get_test_tr_single_sig_xprv()) .network(Network::Testnet) .create_wallet_no_persist() .unwrap();