diff --git a/payjoin-cli/src/app/v1.rs b/payjoin-cli/src/app/v1.rs index 7f2666ad6..a0ef11250 100644 --- a/payjoin-cli/src/app/v1.rs +++ b/payjoin-cli/src/app/v1.rs @@ -139,8 +139,11 @@ impl App { let pj_part = payjoin::Url::parse(pj_part) .map_err(|e| anyhow!("Failed to parse pj_endpoint: {}", e))?; - let mut pj_uri = - payjoin::receive::v1::build_v1_pj_uri(&pj_receiver_address, &pj_part, false)?; + let mut pj_uri = payjoin::receive::v1::build_v1_pj_uri( + &pj_receiver_address, + &pj_part, + payjoin::OutputSubstitution::Enabled, + )?; pj_uri.amount = Some(amount); Ok(pj_uri.to_string()) diff --git a/payjoin/src/lib.rs b/payjoin/src/lib.rs index db78e0dbe..5662a017a 100644 --- a/payjoin/src/lib.rs +++ b/payjoin/src/lib.rs @@ -54,8 +54,11 @@ mod request; #[cfg(feature = "_core")] pub use request::*; #[cfg(feature = "_core")] +pub(crate) mod output_substitution; +#[cfg(feature = "v1")] +pub use output_substitution::OutputSubstitution; +#[cfg(feature = "_core")] mod uri; - #[cfg(feature = "_core")] pub use into_url::{Error as IntoUrlError, IntoUrl}; #[cfg(feature = "_core")] diff --git a/payjoin/src/output_substitution.rs b/payjoin/src/output_substitution.rs new file mode 100644 index 000000000..e6ec9090e --- /dev/null +++ b/payjoin/src/output_substitution.rs @@ -0,0 +1,20 @@ +/// Whether the receiver is allowed to substitute original outputs or not. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "v2", derive(serde::Serialize, serde::Deserialize))] +pub enum OutputSubstitution { + Enabled, + Disabled, +} + +impl OutputSubstitution { + /// Combine two output substitution flags. + /// + /// If both are enabled, the result is enabled. + /// If one is disabled, the result is disabled. + pub(crate) fn combine(self, other: Self) -> Self { + match (self, other) { + (Self::Enabled, Self::Enabled) => Self::Enabled, + _ => Self::Disabled, + } + } +} diff --git a/payjoin/src/receive/optional_parameters.rs b/payjoin/src/receive/optional_parameters.rs index e777ab8ee..c3d3257ce 100644 --- a/payjoin/src/receive/optional_parameters.rs +++ b/payjoin/src/receive/optional_parameters.rs @@ -4,12 +4,14 @@ use std::fmt; use bitcoin::FeeRate; use log::warn; +use crate::output_substitution::OutputSubstitution; + #[derive(Debug, Clone)] pub(crate) struct Params { // version pub v: usize, // disableoutputsubstitution - pub disable_output_substitution: bool, + pub output_substitution: OutputSubstitution, // maxadditionalfeecontribution, additionalfeeoutputindex pub additional_fee_contribution: Option<(bitcoin::Amount, usize)>, // minfeerate @@ -23,7 +25,7 @@ impl Default for Params { fn default() -> Self { Params { v: 1, - disable_output_substitution: false, + output_substitution: OutputSubstitution::Enabled, additional_fee_contribution: None, min_fee_rate: FeeRate::BROADCAST_MIN, #[cfg(feature = "_multiparty")] @@ -88,7 +90,11 @@ impl Params { Err(_) => return Err(Error::FeeRate), }, ("disableoutputsubstitution", v) => - params.disable_output_substitution = v == "true", + params.output_substitution = if v == "true" { + OutputSubstitution::Disabled + } else { + OutputSubstitution::Enabled + }, #[cfg(feature = "_multiparty")] ("optimisticmerge", v) => params.optimistic_merge = v == "true", _ => (), diff --git a/payjoin/src/receive/v1/exclusive/mod.rs b/payjoin/src/receive/v1/exclusive/mod.rs index 0fc3bd4cb..1f498e147 100644 --- a/payjoin/src/receive/v1/exclusive/mod.rs +++ b/payjoin/src/receive/v1/exclusive/mod.rs @@ -16,10 +16,9 @@ pub trait Headers { pub fn build_v1_pj_uri<'a>( address: &bitcoin::Address, endpoint: impl IntoUrl, - disable_output_substitution: bool, + output_substitution: OutputSubstitution, ) -> Result, crate::into_url::Error> { - let extras = - crate::uri::PayjoinExtras { endpoint: endpoint.into_url()?, disable_output_substitution }; + let extras = crate::uri::PayjoinExtras { endpoint: endpoint.into_url()?, output_substitution }; Ok(bitcoin_uri::Uri::with_extras(address.clone(), extras)) } diff --git a/payjoin/src/receive/v1/mod.rs b/payjoin/src/receive/v1/mod.rs index 6ccaa58e5..9208cb848 100644 --- a/payjoin/src/receive/v1/mod.rs +++ b/payjoin/src/receive/v1/mod.rs @@ -39,6 +39,7 @@ use super::optional_parameters::Params; use super::{ ImplementationError, InputPair, OutputSubstitutionError, ReplyableError, SelectionError, }; +use crate::output_substitution::OutputSubstitution; use crate::psbt::PsbtExt; use crate::receive::InternalPayloadError; @@ -264,9 +265,8 @@ pub struct WantsOutputs { } impl WantsOutputs { - pub fn is_output_substitution_disabled(&self) -> bool { - self.params.disable_output_substitution - } + /// Whether the receiver is allowed to substitute original outputs or not. + pub fn output_substitution(&self) -> OutputSubstitution { self.params.output_substitution } /// Substitute the receiver output script with the provided script. pub fn substitute_receiver_script( @@ -306,7 +306,7 @@ impl WantsOutputs { // Select an output with the same address if one was provided Some(pos) => { let txo = replacement_outputs.swap_remove(pos); - if self.params.disable_output_substitution + if self.output_substitution() == OutputSubstitution::Disabled && txo.value < original_output.value { return Err( @@ -317,7 +317,7 @@ impl WantsOutputs { } // Otherwise randomly select one of the replacement outputs None => { - if self.params.disable_output_substitution { + if self.output_substitution() == OutputSubstitution::Disabled { return Err( InternalOutputSubstitutionError::ScriptPubKeyChangedWhenDisabled .into(), @@ -708,7 +708,7 @@ impl ProvisionalProposal { self.payjoin_psbt.inputs[i].tap_key_sig = None; } - PayjoinProposal { payjoin_psbt: self.payjoin_psbt, params: self.params } + PayjoinProposal { payjoin_psbt: self.payjoin_psbt } } /// Return the indexes of the sender inputs @@ -763,7 +763,6 @@ impl ProvisionalProposal { #[derive(Debug, Clone)] pub struct PayjoinProposal { payjoin_psbt: Psbt, - params: Params, } impl PayjoinProposal { @@ -771,10 +770,6 @@ impl PayjoinProposal { self.payjoin_psbt.unsigned_tx.input.iter().map(|input| &input.previous_output) } - pub fn is_output_substitution_disabled(&self) -> bool { - self.params.disable_output_substitution - } - pub fn psbt(&self) -> &Psbt { &self.payjoin_psbt } } @@ -949,8 +944,7 @@ pub(crate) mod test { #[test] fn test_pjos_disabled() { let mut proposal = proposal_from_test_vector().unwrap(); - // Specify outputsubstitution is disabled - proposal.params.disable_output_substitution = true; + proposal.params.output_substitution = OutputSubstitution::Disabled; let wants_outputs = wants_outputs_from_test_vector(proposal).unwrap(); let output_value = diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index de5d3c9e5..050d89aa1 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -18,6 +18,7 @@ use super::{ }; use crate::hpke::{decrypt_message_a, encrypt_message_b, HpkeKeyPair, HpkePublicKey}; use crate::ohttp::{ohttp_decapsulate, ohttp_encapsulate, OhttpEncapsulationError, OhttpKeys}; +use crate::output_substitution::OutputSubstitution; use crate::receive::{parse_payload, InputPair}; use crate::uri::ShortId; use crate::{IntoUrl, IntoUrlError, Request}; @@ -192,7 +193,7 @@ impl Receiver { // // see: https://github.com/bitcoin/bips/blob/master/bip-0078.mediawiki#unsecured-payjoin-server if params.v == 1 { - params.disable_output_substitution = true; + params.output_substitution = OutputSubstitution::Disabled; // Additionally V1 sessions never have an optimistic merge opportunity #[cfg(feature = "_multiparty")] @@ -212,7 +213,8 @@ impl Receiver { pj.set_receiver_pubkey(self.context.s.public_key().clone()); pj.set_ohttp(self.context.ohttp_keys.clone()); pj.set_exp(self.context.expiry); - let extras = PayjoinExtras { endpoint: pj, disable_output_substitution: false }; + let extras = + PayjoinExtras { endpoint: pj, output_substitution: OutputSubstitution::Enabled }; bitcoin_uri::Uri::with_extras(self.context.address.clone(), extras) } @@ -385,9 +387,8 @@ pub struct WantsOutputs { } impl WantsOutputs { - pub fn is_output_substitution_disabled(&self) -> bool { - self.v1.is_output_substitution_disabled() - } + /// Whether the receiver is allowed to substitute original outputs or not. + pub fn output_substitution(&self) -> OutputSubstitution { self.v1.output_substitution() } /// Substitute the receiver output script with the provided script. pub fn substitute_receiver_script( @@ -503,10 +504,6 @@ impl PayjoinProposal { self.v1.utxos_to_be_locked() } - pub fn is_output_substitution_disabled(&self) -> bool { - self.v1.is_output_substitution_disabled() - } - pub fn psbt(&self) -> &Psbt { self.v1.psbt() } pub fn extract_v2_req( @@ -652,6 +649,6 @@ mod test { fn test_v2_pj_uri() { let uri = Receiver { context: SHARED_CONTEXT.clone() }.pj_uri(); assert_ne!(uri.extras.endpoint, EXAMPLE_URL.clone()); - assert!(!uri.extras.disable_output_substitution); + assert_eq!(uri.extras.output_substitution, OutputSubstitution::Enabled); } } diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index dc85923ca..83a7188c8 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -17,6 +17,7 @@ pub use error::{BuildSenderError, ResponseError, ValidationError, WellKnownError pub(crate) use error::{InternalBuildSenderError, InternalProposalError, InternalValidationError}; use url::Url; +use crate::output_substitution::OutputSubstitution; use crate::psbt::PsbtExt; // See usize casts @@ -51,7 +52,7 @@ pub(crate) struct AdditionalFeeContribution { #[derive(Debug, Clone)] pub struct PsbtContext { original_psbt: Psbt, - disable_output_substitution: bool, + output_substitution: OutputSubstitution, fee_contribution: Option, min_fee_rate: FeeRate, payee: ScriptBuf, @@ -253,7 +254,7 @@ impl PsbtContext { if original_output.script_pubkey == self.payee => { ensure!( - !self.disable_output_substitution + self.output_substitution == OutputSubstitution::Enabled || (proposed_txout.script_pubkey == original_output.script_pubkey && proposed_txout.value >= original_output.value), DisallowedOutputSubstitution @@ -414,14 +415,14 @@ fn determine_fee_contribution( fn serialize_url( endpoint: Url, - disable_output_substitution: bool, + output_substitution: OutputSubstitution, fee_contribution: Option, min_fee_rate: FeeRate, version: &str, ) -> Url { let mut url = endpoint; url.query_pairs_mut().append_pair("v", version); - if disable_output_substitution { + if output_substitution == OutputSubstitution::Disabled { url.query_pairs_mut().append_pair("disableoutputsubstitution", "true"); } if let Some(AdditionalFeeContribution { max_amount, vout }) = fee_contribution { @@ -449,6 +450,7 @@ mod test { use super::{ check_single_payee, clear_unneeded_fields, determine_fee_contribution, serialize_url, }; + use crate::output_substitution::OutputSubstitution; use crate::psbt::PsbtExt; use crate::send::{AdditionalFeeContribution, InternalBuildSenderError, InternalProposalError}; @@ -456,7 +458,7 @@ mod test { let payee = PARSED_ORIGINAL_PSBT.unsigned_tx.output[1].script_pubkey.clone(); Ok(super::PsbtContext { original_psbt: PARSED_ORIGINAL_PSBT.clone(), - disable_output_substitution: false, + output_substitution: OutputSubstitution::Enabled, fee_contribution: Some(AdditionalFeeContribution { max_amount: bitcoin::Amount::from_sat(182), vout: 0, @@ -641,10 +643,22 @@ mod test { #[test] fn test_disable_output_substitution_query_param() -> Result<(), BoxError> { - let url = serialize_url(Url::parse("http://localhost")?, true, None, FeeRate::ZERO, "2"); + let url = serialize_url( + Url::parse("http://localhost")?, + OutputSubstitution::Disabled, + None, + FeeRate::ZERO, + "2", + ); assert_eq!(url, Url::parse("http://localhost?v=2&disableoutputsubstitution=true")?); - let url = serialize_url(Url::parse("http://localhost")?, false, None, FeeRate::ZERO, "2"); + let url = serialize_url( + Url::parse("http://localhost")?, + OutputSubstitution::Enabled, + None, + FeeRate::ZERO, + "2", + ); assert_eq!(url, Url::parse("http://localhost?v=2")?); Ok(()) } diff --git a/payjoin/src/send/multiparty/mod.rs b/payjoin/src/send/multiparty/mod.rs index b70e51255..e2b5b30b7 100644 --- a/payjoin/src/send/multiparty/mod.rs +++ b/payjoin/src/send/multiparty/mod.rs @@ -10,6 +10,7 @@ use super::v2::{self, extract_request, EncapsulationError, HpkeContext}; use super::{serialize_url, AdditionalFeeContribution, BuildSenderError, InternalResult}; use crate::hpke::decrypt_message_b; use crate::ohttp::ohttp_decapsulate; +use crate::output_substitution::OutputSubstitution; use crate::receive::ImplementationError; use crate::send::v2::V2PostContext; use crate::uri::UrlExt; @@ -48,7 +49,7 @@ impl Sender { .map_err(|_| InternalCreateRequestError::MissingOhttpConfig)?; let body = serialize_v2_body( &self.0.v1.psbt, - self.0.v1.disable_output_substitution, + self.0.v1.output_substitution, self.0.v1.fee_contribution, self.0.v1.min_fee_rate, )?; @@ -65,7 +66,7 @@ impl Sender { endpoint: self.0.endpoint().clone(), psbt_ctx: crate::send::PsbtContext { original_psbt: self.0.v1.psbt.clone(), - disable_output_substitution: self.0.v1.disable_output_substitution, + output_substitution: self.0.v1.output_substitution, fee_contribution: self.0.v1.fee_contribution, payee: self.0.v1.payee.clone(), min_fee_rate: self.0.v1.min_fee_rate, @@ -79,13 +80,13 @@ impl Sender { fn serialize_v2_body( psbt: &Psbt, - disable_output_substitution: bool, + output_substitution: OutputSubstitution, fee_contribution: Option, min_fee_rate: FeeRate, ) -> Result, CreateRequestError> { let mut url = serialize_url( Url::parse("http://localhost").unwrap(), - disable_output_substitution, + output_substitution, fee_contribution, min_fee_rate, "2", @@ -175,7 +176,12 @@ impl FinalizeContext { ohttp_relay: Url, ) -> Result<(Request, ohttp::ClientResponse), CreateRequestError> { let reply_key = self.hpke_ctx.reply_pair.secret_key(); - let body = serialize_v2_body(&self.psbt, false, None, FeeRate::BROADCAST_MIN)?; + let body = serialize_v2_body( + &self.psbt, + OutputSubstitution::Disabled, + None, + FeeRate::BROADCAST_MIN, + )?; let mut ohttp_keys = self .directory_url .ohttp() @@ -237,17 +243,29 @@ mod test { use payjoin_test_utils::BoxError; use url::Url; + use crate::output_substitution::OutputSubstitution; use crate::send::multiparty::append_optimisitic_merge_query_param; use crate::send::serialize_url; #[test] fn test_optimistic_merge_query_param() -> Result<(), BoxError> { - let mut url = - serialize_url(Url::parse("http://localhost")?, false, None, FeeRate::ZERO, "2"); + let mut url = serialize_url( + Url::parse("http://localhost")?, + OutputSubstitution::Enabled, + None, + FeeRate::ZERO, + "2", + ); append_optimisitic_merge_query_param(&mut url); assert_eq!(url, Url::parse("http://localhost?v=2&optimisticmerge=true")?); - let url = serialize_url(Url::parse("http://localhost")?, false, None, FeeRate::ZERO, "2"); + let url = serialize_url( + Url::parse("http://localhost")?, + OutputSubstitution::Enabled, + None, + FeeRate::ZERO, + "2", + ); assert_eq!(url, Url::parse("http://localhost?v=2")?); Ok(()) diff --git a/payjoin/src/send/v1.rs b/payjoin/src/send/v1.rs index e7679f646..5a763b72b 100644 --- a/payjoin/src/send/v1.rs +++ b/payjoin/src/send/v1.rs @@ -27,15 +27,16 @@ use error::{BuildSenderError, InternalBuildSenderError}; use url::Url; use super::*; +pub use crate::output_substitution::OutputSubstitution; use crate::psbt::PsbtExt; use crate::request::Request; -use crate::PjUri; +pub use crate::PjUri; #[derive(Clone)] pub struct SenderBuilder<'a> { pub(crate) psbt: Psbt, pub(crate) uri: PjUri<'a>, - pub(crate) disable_output_substitution: bool, + pub(crate) output_substitution: OutputSubstitution, pub(crate) fee_contribution: Option<(bitcoin::Amount, Option)>, /// Decreases the fee contribution instead of erroring. /// @@ -56,7 +57,7 @@ impl<'a> SenderBuilder<'a> { psbt, uri, // Sender's optional parameters - disable_output_substitution: false, + output_substitution: OutputSubstitution::Enabled, fee_contribution: None, clamp_fee_contribution: false, min_fee_rate: FeeRate::ZERO, @@ -69,8 +70,8 @@ impl<'a> SenderBuilder<'a> { /// It is generally **not** recommended to set this as it may prevent the receiver from /// doing advanced operations such as opening LN channels and it also guarantees the /// receiver will **not** reward the sender with a discount. - pub fn always_disable_output_substitution(mut self, disable: bool) -> Self { - self.disable_output_substitution = disable; + pub fn always_disable_output_substitution(mut self) -> Self { + self.output_substitution = OutputSubstitution::Disabled; self } @@ -185,8 +186,8 @@ impl<'a> SenderBuilder<'a> { self.psbt.validate().map_err(InternalBuildSenderError::InconsistentOriginalPsbt)?; psbt.validate_input_utxos().map_err(InternalBuildSenderError::InvalidOriginalInput)?; let endpoint = self.uri.extras.endpoint.clone(); - let disable_output_substitution = - self.uri.extras.disable_output_substitution || self.disable_output_substitution; + let output_substitution = + self.uri.extras.output_substitution.combine(self.output_substitution); let payee = self.uri.address.script_pubkey(); check_single_payee(&psbt, &payee, self.uri.amount)?; @@ -201,7 +202,7 @@ impl<'a> SenderBuilder<'a> { Ok(Sender { psbt, endpoint, - disable_output_substitution, + output_substitution, fee_contribution, payee, min_fee_rate: self.min_fee_rate, @@ -216,8 +217,8 @@ pub struct Sender { pub(crate) psbt: Psbt, /// The payjoin directory subdirectory to send the request to. pub(crate) endpoint: Url, - /// Disallow receiver to substitute original outputs. - pub(crate) disable_output_substitution: bool, + /// Whether the receiver is allowed to substitute original outputs. + pub(crate) output_substitution: OutputSubstitution, /// (maxadditionalfeecontribution, additionalfeeoutputindex) pub(crate) fee_contribution: Option, pub(crate) min_fee_rate: FeeRate, @@ -230,7 +231,7 @@ impl Sender { pub fn extract_v1(&self) -> (Request, V1Context) { let url = serialize_url( self.endpoint.clone(), - self.disable_output_substitution, + self.output_substitution, self.fee_contribution, self.min_fee_rate, "1", // payjoin version @@ -241,7 +242,7 @@ impl Sender { V1Context { psbt_context: PsbtContext { original_psbt: self.psbt.clone(), - disable_output_substitution: self.disable_output_substitution, + output_substitution: self.output_substitution, fee_contribution: self.fee_contribution, payee: self.payee.clone(), min_fee_rate: self.min_fee_rate, diff --git a/payjoin/src/send/v2/mod.rs b/payjoin/src/send/v2/mod.rs index bc4db41d8..09656daaf 100644 --- a/payjoin/src/send/v2/mod.rs +++ b/payjoin/src/send/v2/mod.rs @@ -54,8 +54,8 @@ impl<'a> SenderBuilder<'a> { /// It is generally **not** recommended to set this as it may prevent the receiver from /// doing advanced operations such as opening LN channels and it also guarantees the /// receiver will **not** reward the sender with a discount. - pub fn always_disable_output_substitution(self, disable: bool) -> Self { - Self(self.0.always_disable_output_substitution(disable)) + pub fn always_disable_output_substitution(self) -> Self { + Self(self.0.always_disable_output_substitution()) } // Calculate the recommended fee contribution for an Original PSBT. @@ -150,7 +150,7 @@ impl Sender { .map_err(|_| InternalCreateRequestError::MissingOhttpConfig)?; let body = serialize_v2_body( &self.v1.psbt, - self.v1.disable_output_substitution, + self.v1.output_substitution, self.v1.fee_contribution, self.v1.min_fee_rate, )?; @@ -169,7 +169,7 @@ impl Sender { endpoint: self.v1.endpoint.clone(), psbt_ctx: PsbtContext { original_psbt: self.v1.psbt.clone(), - disable_output_substitution: self.v1.disable_output_substitution, + output_substitution: self.v1.output_substitution, fee_contribution: self.v1.fee_contribution, payee: self.v1.payee.clone(), min_fee_rate: self.v1.min_fee_rate, @@ -220,7 +220,7 @@ pub(crate) fn extract_request( pub(crate) fn serialize_v2_body( psbt: &Psbt, - disable_output_substitution: bool, + output_substitution: OutputSubstitution, fee_contribution: Option, min_fee_rate: FeeRate, ) -> Result, CreateRequestError> { @@ -229,7 +229,7 @@ pub(crate) fn serialize_v2_body( let placeholder_url = serialize_url( base_url, - disable_output_substitution, + output_substitution, fee_contribution, min_fee_rate, "2", // payjoin version @@ -365,7 +365,7 @@ mod test { v1: v1::Sender { psbt: PARSED_ORIGINAL_PSBT.clone(), endpoint, - disable_output_substitution: false, + output_substitution: OutputSubstitution::Enabled, fee_contribution: None, min_fee_rate: FeeRate::ZERO, payee: ScriptBuf::from(vec![0x00]), @@ -395,7 +395,7 @@ mod test { let sender = create_sender_context()?; let body = serialize_v2_body( &sender.v1.psbt, - sender.v1.disable_output_substitution, + sender.v1.output_substitution, sender.v1.fee_contribution, sender.v1.min_fee_rate, ); diff --git a/payjoin/src/uri/mod.rs b/payjoin/src/uri/mod.rs index 7e154f276..3f44e7f65 100644 --- a/payjoin/src/uri/mod.rs +++ b/payjoin/src/uri/mod.rs @@ -6,6 +6,7 @@ use url::Url; #[cfg(feature = "v2")] pub(crate) use crate::directory::ShortId; +use crate::output_substitution::OutputSubstitution; use crate::uri::error::InternalPjParseError; #[cfg(feature = "v2")] pub(crate) use crate::uri::url_ext::UrlExt; @@ -29,14 +30,18 @@ impl MaybePayjoinExtras { } } +/// Validated payjoin parameters #[derive(Debug, Clone)] pub struct PayjoinExtras { + /// pj parameter pub(crate) endpoint: Url, - pub(crate) disable_output_substitution: bool, + /// pjos parameter + pub(crate) output_substitution: OutputSubstitution, } impl PayjoinExtras { pub fn endpoint(&self) -> &Url { &self.endpoint } + pub fn output_substitution(&self) -> OutputSubstitution { self.output_substitution } } pub type Uri<'a, NetworkValidation> = bitcoin_uri::Uri<'a, NetworkValidation, MaybePayjoinExtras>; @@ -80,10 +85,6 @@ impl<'a> UriExt<'a> for Uri<'a, NetworkChecked> { } } -impl PayjoinExtras { - pub fn is_output_substitution_disabled(&self) -> bool { self.disable_output_substitution } -} - impl bitcoin_uri::de::DeserializationError for MaybePayjoinExtras { type Error = PjParseError; } @@ -95,7 +96,7 @@ impl bitcoin_uri::de::DeserializeParams<'_> for MaybePayjoinExtras { #[derive(Default)] pub struct DeserializationState { pj: Option, - pjos: Option, + pjos: Option, } impl bitcoin_uri::SerializeParams for &MaybePayjoinExtras { @@ -127,11 +128,12 @@ impl bitcoin_uri::SerializeParams for &PayjoinExtras { .replacen(scheme, &scheme.to_uppercase(), 1) .replacen(host, &host.to_uppercase(), 1); - vec![ - ("pjos", if self.disable_output_substitution { "1" } else { "0" }.to_string()), - ("pj", endpoint_str), - ] - .into_iter() + let mut params = Vec::with_capacity(2); + if self.output_substitution == OutputSubstitution::Disabled { + params.push(("pjos", String::from("0"))); + } + params.push(("pj", endpoint_str)); + params.into_iter() } } @@ -166,8 +168,8 @@ impl bitcoin_uri::de::DeserializationState<'_> for DeserializationState { "pj" => Err(InternalPjParseError::DuplicateParams("pj").into()), "pjos" if self.pjos.is_none() => { match &*Cow::try_from(value).map_err(|_| InternalPjParseError::BadPjOs)? { - "0" => self.pjos = Some(false), - "1" => self.pjos = Some(true), + "0" => self.pjos = Some(OutputSubstitution::Disabled), + "1" => self.pjos = Some(OutputSubstitution::Enabled), _ => return Err(InternalPjParseError::BadPjOs.into()), } Ok(bitcoin_uri::de::ParamKind::Known) @@ -191,7 +193,7 @@ impl bitcoin_uri::de::DeserializationState<'_> for DeserializationState { { Ok(MaybePayjoinExtras::Supported(PayjoinExtras { endpoint, - disable_output_substitution: pjos.unwrap_or(false), + output_substitution: pjos.unwrap_or(OutputSubstitution::Enabled), })) } else { Err(InternalPjParseError::UnsecureEndpoint.into()) @@ -272,4 +274,60 @@ mod tests { .extras .pj_is_supported()); } + + #[test] + fn test_serialize_pjos() { + let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pj=HTTPS://EXAMPLE.COM/%23OH1QYPM5JXYNS754Y4R45QWE336QFX6ZR8DQGVQCULVZTV20TFVEYDMFQC"; + let expected_is_disabled = "pjos=0"; + let expected_is_enabled = "pjos=1"; + let mut pjuri = Uri::try_from(uri) + .expect("Invalid uri") + .assume_checked() + .check_pj_supported() + .expect("Could not parse pj extras"); + + pjuri.extras.output_substitution = OutputSubstitution::Disabled; + assert!( + pjuri.to_string().contains(expected_is_disabled), + "Pj uri should contain param: {}, but it did not", + expected_is_disabled + ); + + pjuri.extras.output_substitution = OutputSubstitution::Enabled; + assert!( + !pjuri.to_string().contains(expected_is_enabled), + "Pj uri should elide param: {}, but it did not", + expected_is_enabled + ); + } + + #[test] + fn test_deserialize_pjos() { + // pjos=0 should disable output substitution + let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pj=https://example.com&pjos=0"; + let parsed = Uri::try_from(uri).unwrap(); + match parsed.extras { + MaybePayjoinExtras::Supported(extras) => + assert_eq!(extras.output_substitution, OutputSubstitution::Disabled), + _ => panic!("Expected Supported PayjoinExtras"), + } + + // pjos=1 should allow output substitution + let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pj=https://example.com&pjos=1"; + let parsed = Uri::try_from(uri).unwrap(); + match parsed.extras { + MaybePayjoinExtras::Supported(extras) => + assert_eq!(extras.output_substitution, OutputSubstitution::Enabled), + _ => panic!("Expected Supported PayjoinExtras"), + } + + // Elided pjos=1 should allow output substitution + let uri = "bitcoin:12c6DSiU4Rq3P4ZxziKxzrL5LmMBrzjrJX?pj=https://example.com"; + let parsed = Uri::try_from(uri).unwrap(); + match parsed.extras { + MaybePayjoinExtras::Supported(extras) => + assert_eq!(extras.output_substitution, OutputSubstitution::Enabled), + _ => panic!("Expected Supported PayjoinExtras"), + } + } } diff --git a/payjoin/tests/integration.rs b/payjoin/tests/integration.rs index 572bb77e7..fc3456d68 100644 --- a/payjoin/tests/integration.rs +++ b/payjoin/tests/integration.rs @@ -12,7 +12,7 @@ mod integration { use payjoin::receive::v1::build_v1_pj_uri; use payjoin::receive::ReplyableError::Implementation; use payjoin::receive::{ImplementationError, InputPair}; - use payjoin::{PjUri, Request, Uri}; + use payjoin::{OutputSubstitution, PjUri, Request, Uri}; use payjoin_test_utils::{init_bitcoind_sender_receiver, init_tracing, BoxError}; const EXAMPLE_URL: &str = "https://example.com"; @@ -72,7 +72,8 @@ mod integration { ) -> Result<(), BoxError> { // Receiver creates the payjoin URI let pj_receiver_address = receiver.get_new_address(None, None)?.assume_checked(); - let mut pj_uri = build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, false)?; + let mut pj_uri = + build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, OutputSubstitution::Enabled)?; pj_uri.amount = Some(Amount::ONE_BTC); // ********************** @@ -136,7 +137,8 @@ mod integration { // Receiver creates the payjoin URI let pj_receiver_address = receiver.get_new_address(None, None)?.assume_checked(); - let mut pj_uri = build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, false)?; + let mut pj_uri = + build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, OutputSubstitution::Enabled)?; pj_uri.amount = Some(Amount::ONE_BTC); // ********************** @@ -350,7 +352,6 @@ mod integration { .process_res(response.bytes().await?.to_vec().as_slice(), ctx)? .expect("proposal should exist"); let mut payjoin_proposal = handle_directory_proposal(&receiver, proposal, None)?; - assert!(!payjoin_proposal.is_output_substitution_disabled()); let (req, ctx) = payjoin_proposal.extract_v2_req(&ohttp_relay)?; let response = agent .post(req.url) @@ -402,7 +403,8 @@ mod integration { let (_bitcoind, sender, receiver) = init_bitcoind_sender_receiver(None, None)?; // Receiver creates the payjoin URI let pj_receiver_address = receiver.get_new_address(None, None)?.assume_checked(); - let mut pj_uri = build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, false)?; + let mut pj_uri = + build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, OutputSubstitution::Enabled)?; pj_uri.amount = Some(Amount::ONE_BTC); // ********************** @@ -529,7 +531,6 @@ mod integration { let mut payjoin_proposal = handle_directory_proposal(&receiver_clone, proposal, None) .map_err(|e| e.to_string())?; - assert!(payjoin_proposal.is_output_substitution_disabled()); // Respond with payjoin psbt within the time window the sender is willing to wait // this response would be returned as http response to the sender let (req, ctx) = payjoin_proposal.extract_v2_req(&ohttp_relay)?; @@ -988,7 +989,8 @@ mod integration { // Receiver creates the payjoin URI let pj_receiver_address = receiver.get_new_address(None, None)?.assume_checked(); - let mut pj_uri = build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, false)?; + let mut pj_uri = + build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, OutputSubstitution::Enabled)?; pj_uri.amount = Some(Amount::ONE_BTC); // ********************** @@ -1065,7 +1067,8 @@ mod integration { // Receiver creates the payjoin URI let pj_receiver_address = receiver.get_new_address(None, None)?.assume_checked(); - let mut pj_uri = build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, false)?; + let mut pj_uri = + build_v1_pj_uri(&pj_receiver_address, EXAMPLE_URL, OutputSubstitution::Enabled)?; pj_uri.amount = Some(Amount::ONE_BTC); // ********************** @@ -1186,7 +1189,6 @@ mod integration { )?; let proposal = handle_proposal(proposal, receiver, custom_outputs, drain_script, custom_inputs)?; - assert!(!proposal.is_output_substitution_disabled()); let psbt = proposal.psbt(); tracing::debug!("Receiver's Payjoin proposal PSBT: {:#?}", &psbt); Ok(psbt.to_string())