diff --git a/payjoin-cli/src/app/v2.rs b/payjoin-cli/src/app/v2.rs index dad11fb80..ac827e5f3 100644 --- a/payjoin-cli/src/app/v2.rs +++ b/payjoin-cli/src/app/v2.rs @@ -4,7 +4,7 @@ use anyhow::{anyhow, Context, Result}; use payjoin::bitcoin::consensus::encode::serialize_hex; use payjoin::bitcoin::psbt::Psbt; use payjoin::bitcoin::{Amount, FeeRate}; -use payjoin::receive::v2::{Receiver, UncheckedProposal}; +use payjoin::receive::v2::{EphemeralReceiver, Receiver, UncheckedProposal}; use payjoin::receive::{Error, ImplementationError, ReplyableError}; use payjoin::send::v2::{Sender, SenderBuilder}; use payjoin::Uri; @@ -52,10 +52,14 @@ impl AppTrait for App { Some(send_session) => send_session, None => { let psbt = self.create_original_psbt(&uri, fee_rate)?; - let mut req_ctx = SenderBuilder::new(psbt, uri.clone()) + let req_ctx = SenderBuilder::new(psbt, uri.clone()) .build_recommended(fee_rate) - .with_context(|| "Failed to build payjoin request")?; - self.db.insert_send_session(&mut req_ctx, url)?; + .with_context(|| "Failed to build payjoin request")? + .persist(|key, sender| { + let mut sender = sender.clone(); + self.db.insert_send_session(&mut sender, key)?; + Ok(()) + })?; req_ctx } }; @@ -65,13 +69,18 @@ impl AppTrait for App { async fn receive_payjoin(&self, amount: Amount) -> Result<()> { let address = self.wallet().get_new_address()?; let ohttp_keys = unwrap_ohttp_keys_or_else_fetch(&self.config).await?; - let session = Receiver::new( + let ephemeral_receiver = EphemeralReceiver::new( address, self.config.v2()?.pj_directory.clone(), ohttp_keys.clone(), None, )?; - self.db.insert_recv_session(session.clone())?; + let session = ephemeral_receiver.persist(|key, r| { + self.db + .insert_recv_session(key, r.clone()) + .map_err(|e| ReplyableError::Implementation(Box::new(e)))?; + Ok(()) + })?; self.spawn_payjoin_receiver(session, Some(amount)).await } diff --git a/payjoin-cli/src/db/v2.rs b/payjoin-cli/src/db/v2.rs index 136c8894e..cbc4d3947 100644 --- a/payjoin-cli/src/db/v2.rs +++ b/payjoin-cli/src/db/v2.rs @@ -7,11 +7,10 @@ use url::Url; use super::*; impl Database { - pub(crate) fn insert_recv_session(&self, session: Receiver) -> Result<()> { + pub(crate) fn insert_recv_session(&self, key: &[u8], session: Receiver) -> Result<()> { let recv_tree = self.0.open_tree("recv_sessions")?; - let key = &session.id(); let value = serde_json::to_string(&session).map_err(Error::Serialize)?; - recv_tree.insert(key.as_slice(), IVec::from(value.as_str()))?; + recv_tree.insert(key, IVec::from(value.as_str()))?; recv_tree.flush()?; Ok(()) } @@ -34,10 +33,10 @@ impl Database { Ok(()) } - pub(crate) fn insert_send_session(&self, session: &mut Sender, pj_url: &Url) -> Result<()> { + pub(crate) fn insert_send_session(&self, session: &mut Sender, key: &[u8]) -> Result<()> { let send_tree: Tree = self.0.open_tree("send_sessions")?; let value = serde_json::to_string(session).map_err(Error::Serialize)?; - send_tree.insert(pj_url.to_string(), IVec::from(value.as_str()))?; + send_tree.insert(key, IVec::from(value.as_str()))?; send_tree.flush()?; Ok(()) } diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 916acc8a5..99179d7df 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -60,6 +60,33 @@ pub struct Receiver { context: SessionContext, } +/// A wrapper around the receiver session. The receiver session is accessible only after it has been persisted via [`EphemeralReceiver::persist`] +pub struct EphemeralReceiver(Receiver); + +impl EphemeralReceiver { + pub fn new( + address: Address, + directory: impl IntoUrl, + ohttp_keys: OhttpKeys, + expire_after: Option, + ) -> Result { + Ok(EphemeralReceiver(Receiver::new(address, directory, ohttp_keys, expire_after)?)) + } + + /// Persist the receiver session to the database. Implementation details are left to the caller. + /// The closure given should accept a slice to be used a key in a key-value store and the receiver which is deserializable. + pub fn persist( + &self, + persist: impl Fn(&[u8], &Receiver) -> Result<(), ImplementationError>, + ) -> Result { + let receiver = self.0.clone(); + let short_id = id(&receiver.context.s); + let id = short_id.0.as_slice(); + persist(id, &receiver).map_err(ReplyableError::Implementation)?; + Ok(receiver) + } +} + impl Receiver { /// Creates a new `Receiver` with the provided parameters. /// @@ -74,7 +101,7 @@ impl Receiver { /// /// # References /// - [BIP 77: Payjoin Version 2: Serverless Payjoin](https://github.com/bitcoin/bips/pull/1483) - pub fn new( + pub(crate) fn new( address: Address, directory: impl IntoUrl, ohttp_keys: OhttpKeys, diff --git a/payjoin/src/send/multiparty/mod.rs b/payjoin/src/send/multiparty/mod.rs index 40bf2fabd..d064494c2 100644 --- a/payjoin/src/send/multiparty/mod.rs +++ b/payjoin/src/send/multiparty/mod.rs @@ -24,7 +24,8 @@ impl<'a> SenderBuilder<'a> { pub fn new(psbt: Psbt, uri: PjUri<'a>) -> Self { Self(v2::SenderBuilder::new(psbt, uri)) } pub fn build_recommended(self, min_fee_rate: FeeRate) -> Result { let v2 = v2::SenderBuilder::new(self.0 .0.psbt, self.0 .0.uri) - .build_recommended(min_fee_rate)?; + .build_recommended(min_fee_rate)? + .persist(|_id, _sender| Ok(()))?; Ok(Sender(v2)) } } diff --git a/payjoin/src/send/v2/mod.rs b/payjoin/src/send/v2/mod.rs index e9cc809f0..76c080165 100644 --- a/payjoin/src/send/v2/mod.rs +++ b/payjoin/src/send/v2/mod.rs @@ -64,11 +64,14 @@ impl<'a> SenderBuilder<'a> { // The minfeerate parameter is set if the contribution is available in change. // // This method fails if no recommendation can be made or if the PSBT is malformed. - pub fn build_recommended(self, min_fee_rate: FeeRate) -> Result { - Ok(Sender { + pub fn build_recommended( + self, + min_fee_rate: FeeRate, + ) -> Result { + Ok(EphemeralSender(Sender { v1: self.0.build_recommended(min_fee_rate)?, reply_key: HpkeKeyPair::gen_keypair().0, - }) + })) } /// Offer the receiver contribution to pay for his input. @@ -90,8 +93,8 @@ impl<'a> SenderBuilder<'a> { change_index: Option, min_fee_rate: FeeRate, clamp_fee_contribution: bool, - ) -> Result { - Ok(Sender { + ) -> Result { + Ok(EphemeralSender(Sender { v1: self.0.build_with_additional_fee( max_fee_contribution, change_index, @@ -99,7 +102,7 @@ impl<'a> SenderBuilder<'a> { clamp_fee_contribution, )?, reply_key: HpkeKeyPair::gen_keypair().0, - }) + })) } /// Perform Payjoin without incentivizing the payee to cooperate. @@ -109,11 +112,29 @@ impl<'a> SenderBuilder<'a> { pub fn build_non_incentivizing( self, min_fee_rate: FeeRate, - ) -> Result { - Ok(Sender { + ) -> Result { + Ok(EphemeralSender(Sender { v1: self.0.build_non_incentivizing(min_fee_rate)?, reply_key: HpkeKeyPair::gen_keypair().0, - }) + })) + } +} + +pub struct EphemeralSender(Sender); + +impl EphemeralSender { + pub fn new(sender: Sender) -> EphemeralSender { EphemeralSender(sender) } + + pub fn persist( + &self, + persist: impl Fn(&[u8], &Sender) -> Result<(), Box>, + ) -> Result { + let sender = self.0.clone(); + let pj_uri = sender.endpoint().to_string(); + let id = pj_uri.as_bytes(); + // TODO(armins): handle unwrap + persist(id, &sender).unwrap(); + Ok(sender) } } diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index be9156c96..91e74718e 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -169,7 +169,7 @@ mod integration { use bitcoin::Address; use http::StatusCode; - use payjoin::receive::v2::{PayjoinProposal, Receiver, UncheckedProposal}; + use payjoin::receive::v2::{EphemeralReceiver, PayjoinProposal, UncheckedProposal}; use payjoin::send::v2::SenderBuilder; use payjoin::{OhttpKeys, PjUri, UriExt}; use payjoin_test_utils::{BoxSendSyncError, TestServices}; @@ -204,7 +204,8 @@ mod integration { let mock_address = Address::from_str("tb1q6d3a2w975yny0asuvd9a67ner4nks58ff0q8g4")? .assume_checked(); let mut bad_initializer = - Receiver::new(mock_address, directory, bad_ohttp_keys, None)?; + EphemeralReceiver::new(mock_address, directory, bad_ohttp_keys, None)? + .persist(|_, _| Ok(()))?; let (req, _ctx) = bad_initializer.extract_req(&mock_ohttp_relay)?; agent.post(req.url).body(req.body).send().await.map_err(|e| e.into()) } @@ -234,12 +235,13 @@ mod integration { // Inside the Receiver: let address = receiver.get_new_address(None, None)?.assume_checked(); // test session with expiry in the past - let mut expired_receiver = Receiver::new( + let mut expired_receiver = EphemeralReceiver::new( address.clone(), directory.clone(), ohttp_keys.clone(), Some(Duration::from_secs(0)), - )?; + )? + .persist(|_key, _receiver| Ok(()))?; match expired_receiver.extract_req(&ohttp_relay) { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), @@ -251,7 +253,8 @@ mod integration { let psbt = build_original_psbt(&sender, &expired_receiver.pj_uri())?; // Test that an expired pj_url errors let expired_req_ctx = SenderBuilder::new(psbt, expired_receiver.pj_uri()) - .build_non_incentivizing(FeeRate::BROADCAST_MIN)?; + .build_non_incentivizing(FeeRate::BROADCAST_MIN)? + .persist(|_key, _sender| Ok(()))?; match expired_req_ctx.extract_v2(directory.to_owned()) { // Internal error types are private, so check against a string Err(err) => assert!(err.to_string().contains("expired")), @@ -286,8 +289,13 @@ mod integration { let address = receiver.get_new_address(None, None)?.assume_checked(); // test session with expiry in the future - let mut session = - Receiver::new(address.clone(), directory.clone(), ohttp_keys.clone(), None)?; + let mut session = EphemeralReceiver::new( + address.clone(), + directory.clone(), + ohttp_keys.clone(), + None, + )? + .persist(|_key, _receiver| Ok(()))?; println!("session: {:#?}", &session); // Poll receive request let mock_ohttp_relay = services.ohttp_gateway_url(); @@ -309,7 +317,8 @@ mod integration { .map_err(|e| e.to_string())?; let psbt = build_sweep_psbt(&sender, &pj_uri)?; let req_ctx = SenderBuilder::new(psbt.clone(), pj_uri.clone()) - .build_recommended(FeeRate::BROADCAST_MIN)?; + .build_recommended(FeeRate::BROADCAST_MIN)? + .persist(|_key, _sender| Ok(()))?; let (Request { url, body, content_type, .. }, send_ctx) = req_ctx.extract_v2(mock_ohttp_relay.to_owned())?; let response = agent @@ -399,7 +408,8 @@ mod integration { .map_err(|e| e.to_string())?; let psbt = build_original_psbt(&sender, &pj_uri)?; let req_ctx = SenderBuilder::new(psbt.clone(), pj_uri.clone()) - .build_recommended(FeeRate::BROADCAST_MIN)?; + .build_recommended(FeeRate::BROADCAST_MIN)? + .persist(|_key, _sender| Ok(()))?; let (req, ctx) = req_ctx.extract_v1()?; let headers = HeaderMock::new(&req.body, req.content_type); @@ -449,7 +459,8 @@ mod integration { let address = receiver.get_new_address(None, None)?.assume_checked(); let mut session = - Receiver::new(address, directory.clone(), ohttp_keys.clone(), None)?; + EphemeralReceiver::new(address, directory.clone(), ohttp_keys.clone(), None)? + .persist(|_key, _receiver| Ok(()))?; // ********************** // Inside the V1 Sender: @@ -468,6 +479,7 @@ mod integration { FeeRate::ZERO, false, )? + .persist(|_key, _sender| Ok(()))? .extract_v1()?; log::info!("send fallback v1 to offline receiver fail"); let res = agent @@ -663,7 +675,7 @@ mod integration { #[cfg(feature = "_multiparty")] mod multiparty { use bitcoin::ScriptBuf; - use payjoin::receive::v2::Receiver; + use payjoin::receive::v2::{EphemeralReceiver, Receiver}; use payjoin::send::multiparty::{ GetContext as MultiPartyGetContext, SenderBuilder as MultiPartySenderBuilder, }; @@ -710,12 +722,13 @@ mod integration { // Senders will generate a sweep psbt and send PSBT to receiver subdir for sender in senders.iter() { let address = receiver.get_new_address(None, None)?.assume_checked(); - let receiver_session = Receiver::new( + let receiver_session = EphemeralReceiver::new( address.clone(), directory.clone(), ohttp_keys.clone(), None, - )?; + )? + .persist(|_key, _receiver| Ok(()))?; let pj_uri = receiver_session.pj_uri(); let psbt = build_sweep_psbt(sender, &pj_uri)?; let sender_ctx = MultiPartySenderBuilder::new(psbt.clone(), pj_uri.clone())